diff --git a/ACKNOWLEDGMENTS .md b/ACKNOWLEDGMENTS.md
similarity index 100%
rename from ACKNOWLEDGMENTS .md
rename to ACKNOWLEDGMENTS.md
diff --git a/conftest.py b/conftest.py
index 08641ff36543dbfba7109f84616ead8d2b472891..abb0c0f52757e4b2228d7d48e3dc07e08b302841 100644
--- a/conftest.py
+++ b/conftest.py
@@ -58,10 +58,13 @@ def default_session_fixture(request):
     :type request: _pytest.python.SubRequest
     :return:
     """
-    patched = mock.patch("multiprocessing.cpu_count", return_value=1)
-    patched.__enter__()
+    # patched = mock.patch("multiprocessing.cpu_count", return_value=1)
+    # patched.__enter__()
 
-    def unpatch():
-        patched.__exit__()
+    # def unpatch():
+    #    patched.__exit__()
 
-    request.addfinalizer(unpatch)
+    # request.addfinalizer(unpatch)
+
+    with mock.patch("multiprocessing.cpu_count", return_value=1):
+        yield
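
The `with`/`yield` form above is the idiomatic pytest pattern for a session-wide patch: the mock stays active for the
fixture's lifetime and is undone automatically, with no manual `__enter__`/`__exit__` bookkeeping. A minimal
standalone sketch (fixture and test names are illustrative):

```python
import multiprocessing
from unittest import mock

import pytest


@pytest.fixture(autouse=True)
def single_cpu():
    # The patch is active while each test runs and is reverted
    # automatically when the with-block exits after the yield.
    with mock.patch("multiprocessing.cpu_count", return_value=1):
        yield


def test_cpu_count_is_patched():
    assert multiprocessing.cpu_count() == 1
```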
diff --git a/mlair/model_modules/__init__.py b/mlair/model_modules/__init__.py
index ea2067bdfdaacb6290157be681786212b0422812..96c92108ca4cbe63460722db12e61d9343593b06 100644
--- a/mlair/model_modules/__init__.py
+++ b/mlair/model_modules/__init__.py
@@ -1,3 +1,3 @@
 """Collection of all modules that are related to a model."""
 
-from .model_class import AbstractModelClass
+from .abstract_model_class import AbstractModelClass
diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..894ff7ac4e787a8b31f75ff932f60bec8c561094
--- /dev/null
+++ b/mlair/model_modules/abstract_model_class.py
@@ -0,0 +1,241 @@
+import inspect
+from abc import ABC
+from typing import Any, Dict, Callable
+
+import keras
+import tensorflow as tf
+
+from mlair.helpers import remove_items
+
+
+class AbstractModelClass(ABC):
+    """
+    The AbstractModelClass provides a unified skeleton for any model provided to the machine learning workflow.
+
+    The model can always be accessed via ModelClass.model or directly through a model method without naming the
+    model attribute explicitly (e.g. ModelClass.model.compile -> ModelClass.compile). Besides the model, this class
+    provides the corresponding loss function.
+    """
+
+    _requirements = []
+
+    def __init__(self, input_shape, output_shape) -> None:
+        """Predefine internal attributes for model and loss."""
+        self.__model = None
+        self.model_name = self.__class__.__name__
+        self.__custom_objects = {}
+        self.__allowed_compile_options = {'optimizer': None,
+                                          'loss': None,
+                                          'metrics': None,
+                                          'loss_weights': None,
+                                          'sample_weight_mode': None,
+                                          'weighted_metrics': None,
+                                          'target_tensors': None
+                                          }
+        self.__compile_options = self.__allowed_compile_options
+        self.__compile_options_is_set = False
+        self._input_shape = input_shape
+        self._output_shape = self.__extract_from_tuple(output_shape)
+
+    def __getattr__(self, name: str) -> Any:
+        """
+        Is called if __getattribute__ is not able to find the requested attribute.
+
+        Normally, the model class is saved into a variable like `model = ModelClass()`. To bypass a call like
+        `model.model` to access the __model attribute, this method searches for the named attribute in the
+        self.model namespace and returns it if available. Therefore, the following expression is true:
+        `ModelClass().compile == ModelClass().model.compile` as long as the called attribute/method is not part of
+        the ModelClass itself.
+
+        :param name: name of the attribute or method to call
+
+        :return: attribute or method from self.model namespace
+        """
+        return self.model.__getattribute__(name)
+
+    @property
+    def model(self) -> keras.Model:
+        """
+        The model property containing a keras.Model instance.
+
+        :return: the keras model
+        """
+        return self.__model
+
+    @model.setter
+    def model(self, value):
+        self.__model = value
+
+    @property
+    def custom_objects(self) -> Dict:
+        """
+        The custom objects property collects all non-keras utilities that are used in the model class.
+
+        To load such a customised and already compiled model (e.g. from local disk), this information is required.
+
+        :return: custom objects in a dictionary
+        """
+        return self.__custom_objects
+
+    @custom_objects.setter
+    def custom_objects(self, value) -> None:
+        self.__custom_objects = value
+
+    @property
+    def compile_options(self) -> Callable:
+        """
+        The compile options property allows the user to use all keras.compile() arguments. They can either be passed
+        as a dictionary (1), as instance attributes without setting compile_options (2), or as a mixture of both
+        (partly defined as instance attributes and partly passed as a dictionary) (3).
+        The setter raises an error if the same parameter is set to different values.
+
+        Example (1) Recommended (includes check for valid keywords which are used as args in keras.compile)
+        .. code-block:: python
+            def set_compile_options(self):
+                self.compile_options = {"optimizer": keras.optimizers.SGD(),
+                                        "loss": keras.losses.mean_squared_error,
+                                        "metrics": ["mse", "mae"]}
+
+        Example (2)
+        .. code-block:: python
+            def set_compile_options(self):
+                self.optimizer = keras.optimizers.SGD()
+                self.loss = keras.losses.mean_squared_error
+                self.metrics = ["mse", "mae"]
+
+        Example (3)
+        Correct:
+        .. code-block:: python
+            def set_compile_options(self):
+                self.optimizer = keras.optimizers.SGD()
+                self.loss = keras.losses.mean_squared_error
+                self.compile_options = {"metrics": ["mse", "mae"]}
+
+        Incorrect: (Will raise an error)
+        .. code-block:: python
+            def set_compile_options(self):
+                self.optimizer = keras.optimizers.SGD()
+                self.loss = keras.losses.mean_squared_error
+                self.compile_options = {"optimizer" = keras.optimizers.Adam(), "metrics": ["mse", "mae"]}
+
+        Note:
+        * As long as the attribute and the dict entry hold exactly the same value, the setter method will not raise
+        an error
+        * For example (2), there is no check that the attributes are valid compile options
+
+
+        :return:
+        """
+        if self.__compile_options_is_set is False:
+            self.compile_options = None
+        return self.__compile_options
+
+    @compile_options.setter
+    def compile_options(self, value: Dict) -> None:
+        if isinstance(value, dict):
+            if not (set(value.keys()) <= set(self.__allowed_compile_options.keys())):
+                raise ValueError(f"Got invalid key for compile_options. {value.keys()}")
+
+        for allow_k in self.__allowed_compile_options.keys():
+            if hasattr(self, allow_k):
+                new_v_attr = getattr(self, allow_k)
+            else:
+                new_v_attr = None
+            if isinstance(value, dict):
+                new_v_dic = value.pop(allow_k, None)
+            elif value is None:
+                new_v_dic = None
+            else:
+                raise TypeError(f"`compile_options' must be `dict' or `None', but is {type(value)}.")
+            if (new_v_attr == new_v_dic or self.__compare_keras_optimizers(new_v_attr, new_v_dic)) or (
+                    (new_v_attr is None) ^ (new_v_dic is None)):
+                if new_v_attr is not None:
+                    self.__compile_options[allow_k] = new_v_attr
+                else:
+                    self.__compile_options[allow_k] = new_v_dic
+
+            else:
+                raise ValueError(
+                    f"Got different values or arguments for same argument: self.{allow_k}={new_v_attr.__class__} and '{allow_k}': {new_v_dic.__class__}")
+        self.__compile_options_is_set = True
+
+    @staticmethod
+    def __extract_from_tuple(tup):
+        """Return element of tuple if it contains only a single element."""
+        return tup[0] if isinstance(tup, tuple) and len(tup) == 1 else tup
+
+    @staticmethod
+    def __compare_keras_optimizers(first, second):
+        """
+        Compare two optimisers and check whether all their settings are exactly equal.
+
+        :return: True if the optimisers are interchangeable, or False if they are distinguishable.
+        """
+        if first.__class__ == second.__class__ and first.__module__ == 'keras.optimizers':
+            res = True
+            init = tf.global_variables_initializer()
+            with tf.Session() as sess:
+                sess.run(init)
+                for k, v in first.__dict__.items():
+                    try:
+                        res *= sess.run(v) == sess.run(second.__dict__[k])
+                    except TypeError:
+                        res *= v == second.__dict__[k]
+        else:
+            res = False
+        return bool(res)
+
+    def get_settings(self) -> Dict:
+        """
+        Get all instance attributes that are not private to the AbstractModelClass, as a dictionary.
+
+        :return: all non-private instance attributes
+        """
+        return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__"))
+
+    def set_model(self):
+        """Abstract method to set model."""
+        raise NotImplementedError
+
+    def set_compile_options(self):
+        """
+        This method only has to be defined in the child class if additional compile options (other than optimizer
+        and loss) should be used.
+        Has to be set as a dictionary: {'optimizer': None,
+                                      'loss': None,
+                                      'metrics': None,
+                                      'loss_weights': None,
+                                      'sample_weight_mode': None,
+                                      'weighted_metrics': None,
+                                      'target_tensors': None
+                                      }
+
+        :return:
+        """
+        raise NotImplementedError
+
+    def set_custom_objects(self, **kwargs) -> None:
+        """
+        Set custom objects that are not part of keras framework.
+
+        These custom objects are needed if an already compiled model is loaded from disk. There is a special treatment
+        for the Padding2D class, which is a base class for different padding types. For a correct behaviour, all
+        supported subclasses are added as custom objects in addition to the given ones.
+
+        :param kwargs: all custom objects that should be saved
+        """
+        if "Padding2D" in kwargs.keys():
+            kwargs.update(kwargs["Padding2D"].allowed_paddings)
+        self.custom_objects = kwargs
+
+    @classmethod
+    def requirements(cls):
+        """Return requirements and own arguments without duplicates."""
+        return list(set(cls._requirements + cls.own_args()))
+
+    @classmethod
+    def own_args(cls, *args):
+        """Return all arguments (including kwonlyargs)."""
+        arg_spec = inspect.getfullargspec(cls)
+        list_of_args = arg_spec.args + arg_spec.kwonlyargs
+        return remove_items(list_of_args, ["self"] + list(args))
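
As a usage sketch for the relocated class (not part of the diff): a child class only has to implement `set_model` and
`set_compile_options`; the `compile_options` setter then merges attribute- and dictionary-based settings as described
in the docstrings above. Class name and layer sizes are illustrative:

```python
import keras

from mlair.model_modules.abstract_model_class import AbstractModelClass


class TinyModel(AbstractModelClass):

    def __init__(self, input_shape: list, output_shape: list):
        super().__init__(input_shape[0], output_shape[0])
        self.set_model()
        self.set_compile_options()

    def set_model(self):
        x_input = keras.layers.Input(shape=self._input_shape)
        x_in = keras.layers.Flatten()(x_input)
        out = keras.layers.Dense(self._output_shape)(x_in)
        self.model = keras.Model(inputs=x_input, outputs=[out])

    def set_compile_options(self):
        # mixture style (3): optimizer as attribute, the rest via the dict
        self.optimizer = keras.optimizers.SGD()
        self.compile_options = {"loss": keras.losses.mean_squared_error,
                                "metrics": ["mse"]}


model = TinyModel([(14, 1, 2)], [(3,)])
print(model.compile_options["optimizer"])  # the SGD instance set above
```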
diff --git a/mlair/model_modules/fully_connected_networks.py b/mlair/model_modules/fully_connected_networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..940c9846bbaff1a3e3664169cf46de4f177169bc
--- /dev/null
+++ b/mlair/model_modules/fully_connected_networks.py
@@ -0,0 +1,134 @@
+__author__ = "Lukas Leufen"
+__date__ = '2021-02-'
+
+from functools import reduce, partial
+
+from mlair.model_modules import AbstractModelClass
+from mlair.helpers import select_from_dict
+
+import keras
+
+
+class FCN_64_32_16(AbstractModelClass):
+    """
+    A customised model with 4 Dense layers (64, 32, 16, window_lead_time), where the size of the last (output)
+    layer depends on the window_lead_time parameter.
+    """
+
+    def __init__(self, input_shape: list, output_shape: list):
+        """
+        Sets model and loss depending on the given arguments.
+
+        :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+        """
+
+        assert len(input_shape) == 1
+        assert len(output_shape) == 1
+        super().__init__(input_shape[0], output_shape[0])
+
+        # settings
+        self.activation = keras.layers.PReLU
+
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(loss=self.compile_options['loss'])
+
+    def set_model(self):
+        """
+        Build the model.
+        """
+        x_input = keras.layers.Input(shape=self._input_shape)
+        x_in = keras.layers.Flatten()(x_input)
+        x_in = keras.layers.Dense(64, name="Dense_64")(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Dense(32, name="Dense_32")(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Dense(16, name="Dense_16")(x_in)
+        x_in = self.activation()(x_in)
+        x_in = keras.layers.Dense(self._output_shape, name="Dense_output")(x_in)
+        out_main = self.activation()(x_in)
+        self.model = keras.Model(inputs=x_input, outputs=[out_main])
+
+    def set_compile_options(self):
+        self.optimizer = keras.optimizers.adam(lr=1e-2)
+        self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
+
+
+class FCN(AbstractModelClass):
+    """
+    A customisable fully connected network with a configurable number of hidden layers and units per layer, where
+    the size of the last (output) layer depends on the window_lead_time parameter.
+    """
+
+    _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
+                   "sigmoid": partial(keras.layers.Activation, "sigmoid")}
+    _optimizer = {"adam": keras.optimizers.adam, "sgd": keras.optimizers.SGD}
+    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov"]
+
+    def __init__(self, input_shape: list, output_shape: list, activation="relu", optimizer="adam",
+                 n_layer=1, n_hidden=10, **kwargs):
+        """
+        Sets model and loss depending on the given arguments.
+
+        :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+        :param activation: activation function to use, one of "relu", "tanh" or "sigmoid" (default "relu")
+        :param optimizer: optimizer to use, either "adam" or "sgd" (default "adam")
+        :param n_layer: number of hidden layers (default 1)
+        :param n_hidden: number of units per hidden layer (default 10)
+        :param kwargs: optimizer options such as lr, momentum, beta_1, beta_2, epsilon, decay, amsgrad and nesterov
+        """
+
+        assert len(input_shape) == 1
+        assert len(output_shape) == 1
+        super().__init__(input_shape[0], output_shape[0])
+
+        # settings
+        self.activation = self._set_activation(activation)
+        self.optimizer = self._set_optimizer(optimizer, **kwargs)
+        self.layer_configuration = (n_layer, n_hidden)
+        self._update_model_name()
+
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        # self.set_custom_objects(loss=self.compile_options['loss'])
+
+    def _set_activation(self, activation):
+        try:
+            return self._activation[activation.lower()]  # indexing (not .get) so unknown names raise KeyError
+        except KeyError:
+            raise AttributeError(f"Given activation {activation} is not supported in this model class.")
+
+    def _set_optimizer(self, optimizer, **kwargs):
+        try:
+            opt_name = optimizer.lower()
+            opt = self._optimizer[opt_name]  # indexing (not .get) so unknown names raise KeyError
+            opt_kwargs = {}
+            if opt_name == "adam":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
+            elif opt_name == "sgd":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
+            return opt(**opt_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
+
+    def _update_model_name(self):
+        n_layer, n_hidden = self.layer_configuration
+        n_input = str(reduce(lambda x, y: x * y, self._input_shape))
+        n_output = str(self._output_shape)
+        self.model_name += "_".join(["", n_input, *[f"{n_hidden}" for _ in range(n_layer)], n_output])
+
+    def set_model(self):
+        """
+        Build the model.
+        """
+        x_input = keras.layers.Input(shape=self._input_shape)
+        x_in = keras.layers.Flatten()(x_input)
+        n_layer, n_hidden = self.layer_configuration
+        for layer in range(n_layer):
+            x_in = keras.layers.Dense(n_hidden)(x_in)
+            x_in = self.activation()(x_in)
+        x_in = keras.layers.Dense(self._output_shape)(x_in)
+        out = self.activation()(x_in)
+        self.model = keras.Model(inputs=x_input, outputs=[out])
+
+    def set_compile_options(self):
+        self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index a2eda6e8287af2ce489bf75b02d7b205549ff144..f8e3a21a81351ac614e2275749bb85fa82a96e02 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -120,287 +120,14 @@ import mlair.model_modules.keras_extensions
 __author__ = "Lukas Leufen, Felix Kleinert"
 __date__ = '2020-05-12'
 
-from abc import ABC
-from typing import Any, Callable, Dict
-
 import keras
-import tensorflow as tf
+
+from mlair.model_modules import AbstractModelClass
 from mlair.model_modules.inception_model import InceptionModelBase
 from mlair.model_modules.flatten import flatten_tail
 from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D
 
 
-class AbstractModelClass(ABC):
-    """
-    The AbstractModelClass provides a unified skeleton for any model provided to the machine learning workflow.
-
-    The model can always be accessed by calling ModelClass.model or directly by an model method without parsing the
-    model attribute name (e.g. ModelClass.model.compile -> ModelClass.compile). Beside the model, this class provides
-    the corresponding loss function.
-    """
-
-    def __init__(self, input_shape, output_shape) -> None:
-        """Predefine internal attributes for model and loss."""
-        self.__model = None
-        self.model_name = self.__class__.__name__
-        self.__custom_objects = {}
-        self.__allowed_compile_options = {'optimizer': None,
-                                          'loss': None,
-                                          'metrics': None,
-                                          'loss_weights': None,
-                                          'sample_weight_mode': None,
-                                          'weighted_metrics': None,
-                                          'target_tensors': None
-                                          }
-        self.__compile_options = self.__allowed_compile_options
-        self.__compile_options_is_set = False
-        self._input_shape = input_shape
-        self._output_shape = self.__extract_from_tuple(output_shape)
-
-    def __getattr__(self, name: str) -> Any:
-        """
-        Is called if __getattribute__ is not able to find requested attribute.
-
-        Normally, the model class is saved into a variable like `model = ModelClass()`. To bypass a call like
-        `model.model` to access the _model attribute, this method tries to search for the named attribute in the
-        self.model namespace and returns this attribute if available. Therefore, following expression is true:
-        `ModelClass().compile == ModelClass().model.compile` as long the called attribute/method is not part if the
-        ModelClass itself.
-
-        :param name: name of the attribute or method to call
-
-        :return: attribute or method from self.model namespace
-        """
-        return self.model.__getattribute__(name)
-
-    @property
-    def model(self) -> keras.Model:
-        """
-        The model property containing a keras.Model instance.
-
-        :return: the keras model
-        """
-        return self.__model
-
-    @model.setter
-    def model(self, value):
-        self.__model = value
-
-    @property
-    def custom_objects(self) -> Dict:
-        """
-        The custom objects property collects all non-keras utilities that are used in the model class.
-
-        To load such a customised and already compiled model (e.g. from local disk), this information is required.
-
-        :return: custom objects in a dictionary
-        """
-        return self.__custom_objects
-
-    @custom_objects.setter
-    def custom_objects(self, value) -> None:
-        self.__custom_objects = value
-
-    @property
-    def compile_options(self) -> Callable:
-        """
-        The compile options property allows the user to use all keras.compile() arguments. They can ether be passed as
-        dictionary (1), as attribute, without setting compile_options (2) or as mixture (partly defined as instance
-        attributes and partly parsing a dictionary) of both of them (3).
-        The method will raise an Error when the same parameter is set differently.
-
-        Example (1) Recommended (includes check for valid keywords which are used as args in keras.compile)
-        .. code-block:: python
-            def set_compile_options(self):
-                self.compile_options = {"optimizer": keras.optimizers.SGD(),
-                                        "loss": keras.losses.mean_squared_error,
-                                        "metrics": ["mse", "mae"]}
-
-        Example (2)
-        .. code-block:: python
-            def set_compile_options(self):
-                self.optimizer = keras.optimizers.SGD()
-                self.loss = keras.losses.mean_squared_error
-                self.metrics = ["mse", "mae"]
-
-        Example (3)
-        Correct:
-        .. code-block:: python
-            def set_compile_options(self):
-                self.optimizer = keras.optimizers.SGD()
-                self.loss = keras.losses.mean_squared_error
-                self.compile_options = {"metrics": ["mse", "mae"]}
-
-        Incorrect: (Will raise an error)
-        .. code-block:: python
-            def set_compile_options(self):
-                self.optimizer = keras.optimizers.SGD()
-                self.loss = keras.losses.mean_squared_error
-                self.compile_options = {"optimizer" = keras.optimizers.Adam(), "metrics": ["mse", "mae"]}
-
-        Note:
-        * As long as the attribute and the dict value have exactly the same values, the setter method will not raise
-        an error
-        * For example (2) there is no check implemented, if the attributes are valid compile options
-
-
-        :return:
-        """
-        if self.__compile_options_is_set is False:
-            self.compile_options = None
-        return self.__compile_options
-
-    @compile_options.setter
-    def compile_options(self, value: Dict) -> None:
-        if isinstance(value, dict):
-            if not (set(value.keys()) <= set(self.__allowed_compile_options.keys())):
-                raise ValueError(f"Got invalid key for compile_options. {value.keys()}")
-
-        for allow_k in self.__allowed_compile_options.keys():
-            if hasattr(self, allow_k):
-                new_v_attr = getattr(self, allow_k)
-            else:
-                new_v_attr = None
-            if isinstance(value, dict):
-                new_v_dic = value.pop(allow_k, None)
-            elif value is None:
-                new_v_dic = None
-            else:
-                raise TypeError(f"`compile_options' must be `dict' or `None', but is {type(value)}.")
-            if (new_v_attr == new_v_dic or self.__compare_keras_optimizers(new_v_attr, new_v_dic)) or (
-                    (new_v_attr is None) ^ (new_v_dic is None)):
-                if new_v_attr is not None:
-                    self.__compile_options[allow_k] = new_v_attr
-                else:
-                    self.__compile_options[allow_k] = new_v_dic
-
-            else:
-                raise ValueError(
-                    f"Got different values or arguments for same argument: self.{allow_k}={new_v_attr.__class__} and '{allow_k}': {new_v_dic.__class__}")
-        self.__compile_options_is_set = True
-
-    @staticmethod
-    def __extract_from_tuple(tup):
-        """Return element of tuple if it contains only a single element."""
-        return tup[0] if isinstance(tup, tuple) and len(tup) == 1 else tup
-
-    @staticmethod
-    def __compare_keras_optimizers(first, second):
-        """
-        Compares if optimiser and all settings of the optimisers are exactly equal.
-
-        :return True if optimisers are interchangeable, or False if optimisers are distinguishable.
-        """
-        if first.__class__ == second.__class__ and first.__module__ == 'keras.optimizers':
-            res = True
-            init = tf.global_variables_initializer()
-            with tf.Session() as sess:
-                sess.run(init)
-                for k, v in first.__dict__.items():
-                    try:
-                        res *= sess.run(v) == sess.run(second.__dict__[k])
-                    except TypeError:
-                        res *= v == second.__dict__[k]
-        else:
-            res = False
-        return bool(res)
-
-    def get_settings(self) -> Dict:
-        """
-        Get all class attributes that are not protected in the AbstractModelClass as dictionary.
-
-        :return: all class attributes
-        """
-        return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__"))
-
-    def set_model(self):
-        """Abstract method to set model."""
-        raise NotImplementedError
-
-    def set_compile_options(self):
-        """
-        This method only has to be defined in child class, when additional compile options should be used ()
-        (other options than optimizer and loss)
-        Has to be set as dictionary: {'optimizer': None,
-                                      'loss': None,
-                                      'metrics': None,
-                                      'loss_weights': None,
-                                      'sample_weight_mode': None,
-                                      'weighted_metrics': None,
-                                      'target_tensors': None
-                                      }
-
-        :return:
-        """
-        raise NotImplementedError
-
-    def set_custom_objects(self, **kwargs) -> None:
-        """
-        Set custom objects that are not part of keras framework.
-
-        These custom objects are needed if an already compiled model is loaded from disk. There is a special treatment
-        for the Padding2D class, which is a base class for different padding types. For a correct behaviour, all
-        supported subclasses are added as custom objects in addition to the given ones.
-
-        :param kwargs: all custom objects, that should be saved
-        """
-        if "Padding2D" in kwargs.keys():
-            kwargs.update(kwargs["Padding2D"].allowed_paddings)
-        self.custom_objects = kwargs
-
-
-class MyLittleModel(AbstractModelClass):
-    """
-    A customised model 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the output layer depending
-    on the window_lead_time parameter.
-    """
-
-    def __init__(self, input_shape: list, output_shape: list):
-        """
-        Sets model and loss depending on the given arguments.
-
-        :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
-        :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
-        """
-
-        assert len(input_shape) == 1
-        assert len(output_shape) == 1
-        super().__init__(input_shape[0], output_shape[0])
-
-        # settings
-        self.dropout_rate = 0.1
-        self.regularizer = keras.regularizers.l2(0.1)
-        self.activation = keras.layers.PReLU
-
-        # apply to model
-        self.set_model()
-        self.set_compile_options()
-        self.set_custom_objects(loss=self.compile_options['loss'])
-
-    def set_model(self):
-        """
-        Build the model.
-        """
-        x_input = keras.layers.Input(shape=self._input_shape)
-        x_in = keras.layers.Flatten(name='{}'.format("major"))(x_input)
-        x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
-        x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in)
-        x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
-        x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(self._output_shape, name='{}_Dense'.format("major"))(x_in)
-        out_main = self.activation()(x_in)
-        self.model = keras.Model(inputs=x_input, outputs=[out_main])
-
-    def set_compile_options(self):
-        self.initial_lr = 1e-2
-        self.optimizer = keras.optimizers.adam(lr=self.initial_lr)
-        # self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
-        #                                                                        epochs_drop=10)
-        self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
-
-
 class MyLittleModelHourly(AbstractModelClass):
     """
     A customised model with a 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the
@@ -750,8 +477,3 @@ class MyPaperModel(AbstractModelClass):
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
         self.compile_options = {"loss": [keras.losses.mean_squared_error, keras.losses.mean_squared_error],
                                 "metrics": ['mse', 'mae']}
-
-
-if __name__ == "__main__":
-    model = MyLittleModel([(1, 3, 10)], [2])
-    print(model.compile_options)
diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py
index 09f49c848675eb21bd172e40b09b265e47c443fb..9cad9fd0ee2b9f3d81bd91810abcd4f6eeefb05f 100644
--- a/mlair/plotting/training_monitoring.py
+++ b/mlair/plotting/training_monitoring.py
@@ -85,8 +85,9 @@ class PlotModelHistory:
         :param filename: name (including total path) of the plot to save.
         """
         ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7)
+        ax.set_yscale('log')
         if len(self._additional_columns) > 0:
-            self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax)
+            self._data[self._additional_columns].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
         title = f"Model {self._plot_metric}: best = {self._data[[f'val_{self._plot_metric}']].min().values}"
         ax.set(xlabel="epoch", ylabel=self._plot_metric, title=title)
         ax.axhline(y=0, color="gray", linewidth=0.5)
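
For context on the two log-scale calls: `ax.set_yscale('log')` only affects the primary axis, while `secondary_y=True`
plots onto a twin axes that needs its own `logy=True`. A minimal pandas sketch with dummy data:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

data = pd.DataFrame({"loss": np.geomspace(1, 1e-3, 50),
                     "val_loss": np.geomspace(2, 2e-3, 50),
                     "lr": np.geomspace(1e-2, 1e-4, 50)})
ax = data[["loss", "val_loss"]].plot(linewidth=0.7)
ax.set_yscale("log")  # log scale on the primary axis only
data[["lr"]].plot(linewidth=0.7, secondary_y=True, ax=ax, logy=True)
plt.savefig("history.png")
```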
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index af540fc296f1d4b707b5373fdbcbb14dac1afc7f..30672ecc9206319896205d886157b2f2f8977f39 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -20,7 +20,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
-from mlair.model_modules.model_class import MyLittleModel as VanillaModel
+from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
 
 
 class ExperimentSetup(RunEnvironment):
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index dda18fac5d8546c6e399334f3d89415d246a1975..5dd73d50f711387a65a9bc7e4daa7c1d430bfb26 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -56,7 +56,6 @@ class ModelSetup(RunEnvironment):
         """Initialise and run model setup."""
         super().__init__()
         self.model = None
-        # path = self.data_store.get("experiment_path")
         exp_name = self.data_store.get("experiment_name")
         path = self.data_store.get("model_path")
         self.scope = "model"
@@ -138,9 +137,9 @@ class ModelSetup(RunEnvironment):
 
     def build_model(self):
         """Build model using input and output shapes from data store."""
-        args_list = ["input_shape", "output_shape"]
-        args = self.data_store.create_args_dict(args_list, self.scope)
         model = self.data_store.get("model_class")
+        args_list = model.requirements()
+        args = self.data_store.create_args_dict(args_list, self.scope)
         self.model = model(**args)
         self.get_model_settings()
 
@@ -170,6 +169,7 @@ class ModelSetup(RunEnvironment):
     def report_model(self):
         model_settings = self.model.get_settings()
         model_settings.update(self.model.compile_options)
+        model_settings.update(self.model.optimizer.get_config())
         df = pd.DataFrame(columns=["model setting"])
         for k, v in model_settings.items():
             if v is None:
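
To illustrate why `build_model` can now resolve arbitrary constructor arguments: `requirements()` merges the
class-level `_requirements` list with the names found in the constructor signature. A small standalone sketch
mirroring `own_args` (the `FakeModel` class is hypothetical):

```python
import inspect


class FakeModel:
    _requirements = ["lr", "momentum"]

    def __init__(self, input_shape, output_shape, activation="relu", n_layer=1):
        pass


arg_spec = inspect.getfullargspec(FakeModel)
own = [a for a in arg_spec.args + arg_spec.kwonlyargs if a != "self"]
print(sorted(set(FakeModel._requirements + own)))
# ['activation', 'input_shape', 'lr', 'momentum', 'n_layer', 'output_shape']
```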
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 127066b87cdf507519836ff916681e4885053c09..6f78a03d67a0698274eb4795bc8941c590386063 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -17,9 +17,9 @@ from mlair.data_handler import BootStraps, KerasIterator
 from mlair.helpers.datastore import NameNotFoundInDataStore
 from mlair.helpers import TimeTracking, statistics, extract_value, remove_items, to_list
 from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
-from mlair.model_modules.model_class import AbstractModelClass
+from mlair.model_modules import AbstractModelClass
 from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
-    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotAvailabilityHistogram,  \
+    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotAvailabilityHistogram, \
     PlotConditionalQuantiles, PlotSeparationOfScales
 from mlair.run_modules.run_environment import RunEnvironment
 
@@ -390,8 +390,8 @@ class PostProcessing(RunEnvironment):
                                                    use_multiprocessing=True, verbose=0, steps=1)
         path = self.data_store.get("model_path")
         with open(os.path.join(path, "test_scores.txt"), "a") as f:
-            for index, item in enumerate(test_score):
-                logging.info(f"{self.model.metrics_names[index]}, {item}")
+            for index, item in enumerate(to_list(test_score)):
+                logging.info(f"{self.model.metrics_names[index]} (test), {item}")
                 f.write(f"{self.model.metrics_names[index]}, {item}\n")
 
     def train_ols_model(self):
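
The `to_list` wrapper guards against Keras returning a bare scalar when the model was compiled without extra metrics
(a list is only returned when metrics are configured). A minimal sketch, assuming the usual `to_list` semantics from
`mlair.helpers`:

```python
def to_list(obj):
    # hypothetical minimal version of mlair.helpers.to_list
    return obj if isinstance(obj, list) else [obj]


# evaluate_generator returns a scalar loss without extra metrics, but a
# list [loss, metric1, ...] when metrics are configured; to_list makes
# the enumerate loop work in both cases.
for index, item in enumerate(to_list(0.042)):
    print(index, item)
```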
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 6c993d56b540cf3cf5b86d9c1920fc3a22557e46..d4badfe25c94133a53a93ca69f1b2f63a955803c 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -16,6 +16,7 @@ from mlair.model_modules.keras_extensions import CallbackHandler
 from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.configuration import path_config
+from mlair.helpers import to_list
 
 
 class Training(RunEnvironment):
@@ -246,4 +247,8 @@ class Training(RunEnvironment):
         path_config.check_path_and_create(path)
         df.to_latex(os.path.join(path, "training_settings.tex"), na_rep='---', column_format=column_format)
         df.to_markdown(open(os.path.join(path, "training_settings.md"), mode="w", encoding='utf-8'),
-                       tablefmt="github")
\ No newline at end of file
+                       tablefmt="github")
+
+        val_score = self.model.evaluate_generator(generator=self.val_set, use_multiprocessing=True, verbose=0, steps=1)
+        for index, item in enumerate(to_list(val_score)):
+            logging.info(f"{self.model.metrics_names[index]} (val), {item}")
diff --git a/test/test_data_handler/test_iterator.py b/test/test_data_handler/test_iterator.py
index ade5c19215e61de5e209db900920187294ac9b18..e47d725a4fd78fec98e81a6de9c18869e7b47637 100644
--- a/test/test_data_handler/test_iterator.py
+++ b/test/test_data_handler/test_iterator.py
@@ -1,7 +1,7 @@
-
 from mlair.data_handler.iterator import DataCollection, StandardIterator, KerasIterator
 from mlair.helpers.testing import PyTestAllEqual
-from mlair.model_modules.model_class import MyLittleModel, MyBranchedModel
+from mlair.model_modules.model_class import MyBranchedModel
+from mlair.model_modules.fully_connected_networks import FCN_64_32_16
 
 import numpy as np
 import pytest
@@ -275,7 +275,7 @@ class TestKerasIterator:
 
     def test_get_model_rank_single_output_branch(self):
         iterator = object.__new__(KerasIterator)
-        iterator.model = MyLittleModel(input_shape=[(14, 1, 2)], output_shape=[(3,)])
+        iterator.model = FCN_64_32_16(input_shape=[(14, 1, 2)], output_shape=[(3,)])
         assert iterator._get_model_rank() == 1
 
     def test_get_model_rank_multiple_output_branch(self):
diff --git a/test/test_model_modules/test_abstract_model_class.py b/test/test_model_modules/test_abstract_model_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfef68d550b07f824ed38e5c7809c00e5386d115
--- /dev/null
+++ b/test/test_model_modules/test_abstract_model_class.py
@@ -0,0 +1,199 @@
+import keras
+import pytest
+
+from mlair import AbstractModelClass
+
+
+class Paddings:
+    allowed_paddings = {"pad1": 34, "another_pad": True}
+
+
+class AbstractModelSubClass(AbstractModelClass):
+
+    def __init__(self):
+        super().__init__(input_shape=(12, 1, 2), output_shape=3)
+        self.test_attr = "testAttr"
+
+
+class TestAbstractModelClass:
+
+    @pytest.fixture
+    def amc(self):
+        return AbstractModelClass(input_shape=(14, 1, 2), output_shape=(3,))
+
+    @pytest.fixture
+    def amsc(self):
+        return AbstractModelSubClass()
+
+    def test_init(self, amc):
+        assert amc.model is None
+        # assert amc.loss is None
+        assert amc.model_name == "AbstractModelClass"
+        assert amc.custom_objects == {}
+        assert amc._input_shape == (14, 1, 2)
+        assert amc._output_shape == 3
+
+    def test_model_property(self, amc):
+        amc.model = keras.Model()
+        assert isinstance(amc.model, keras.Model) is True
+
+    # def test_loss_property(self, amc):
+    #     amc.loss = keras.losses.mean_absolute_error
+    #     assert amc.loss == keras.losses.mean_absolute_error
+
+    def test_compile_options_setter_all_empty(self, amc):
+        amc.compile_options = None
+        assert amc.compile_options == {'optimizer': None,
+                                       'loss': None,
+                                       'metrics': None,
+                                       'loss_weights': None,
+                                       'sample_weight_mode': None,
+                                       'weighted_metrics': None,
+                                       'target_tensors': None
+                                       }
+
+    def test_compile_options_setter_as_dict(self, amc):
+        amc.compile_options = {"optimizer": keras.optimizers.SGD(),
+                               "loss": keras.losses.mean_absolute_error,
+                               "metrics": ["mse", "mae"]}
+        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
+        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
+        assert amc.compile_options["metrics"] == ["mse", "mae"]
+        assert amc.compile_options["loss_weights"] is None
+        assert amc.compile_options["sample_weight_mode"] is None
+        assert amc.compile_options["target_tensors"] is None
+        assert amc.compile_options["weighted_metrics"] is None
+
+    def test_compile_options_setter_as_attr(self, amc):
+        amc.optimizer = keras.optimizers.SGD()
+        amc.loss = keras.losses.mean_absolute_error
+        amc.compile_options = None  # This line has to be called!
+        # optimizer check
+        assert isinstance(amc.optimizer, keras.optimizers.SGD)
+        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
+        # loss check
+        assert amc.loss == keras.losses.mean_absolute_error
+        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
+        # check rest (all None as not set)
+        assert amc.compile_options["metrics"] is None
+        assert amc.compile_options["loss_weights"] is None
+        assert amc.compile_options["sample_weight_mode"] is None
+        assert amc.compile_options["target_tensors"] is None
+        assert amc.compile_options["weighted_metrics"] is None
+
+    def test_compile_options_setter_as_mix_attr_dict_no_duplicates(self, amc):
+        amc.optimizer = keras.optimizers.SGD()
+        amc.compile_options = {"loss": keras.losses.mean_absolute_error,
+                               "loss_weights": [0.2, 0.8]}
+        # check setting by attribute
+        assert isinstance(amc.optimizer, keras.optimizers.SGD)
+        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
+        # check setting by dict
+        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
+        assert amc.compile_options["loss_weights"] == [0.2, 0.8]
+        # check rest (all None as not set)
+        assert amc.compile_options["metrics"] is None
+        assert amc.compile_options["sample_weight_mode"] is None
+        assert amc.compile_options["target_tensors"] is None
+        assert amc.compile_options["weighted_metrics"] is None
+
+    def test_compile_options_setter_as_mix_attr_dict_valid_duplicates_optimizer(self, amc):
+        amc.optimizer = keras.optimizers.SGD()
+        amc.metrics = ['mse']
+        amc.compile_options = {"optimizer": keras.optimizers.SGD(),
+                               "loss": keras.losses.mean_absolute_error}
+        # check duplicate (attr and dic)
+        assert isinstance(amc.optimizer, keras.optimizers.SGD)
+        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
+        # check setting by dict
+        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
+        # check setting by attr
+        assert amc.metrics == ['mse']
+        assert amc.compile_options["metrics"] == ['mse']
+        # check rest (all None as not set)
+        assert amc.compile_options["loss_weights"] is None
+        assert amc.compile_options["sample_weight_mode"] is None
+        assert amc.compile_options["target_tensors"] is None
+        assert amc.compile_options["weighted_metrics"] is None
+
+    def test_compile_options_setter_as_mix_attr_dict_valid_duplicates_none_optimizer(self, amc):
+        amc.optimizer = keras.optimizers.SGD()
+        amc.metrics = ['mse']
+        amc.compile_options = {"metrics": ['mse'],
+                               "loss": keras.losses.mean_absolute_error}
+        # check duplicate (attr and dic)
+        assert amc.metrics == ['mse']
+        assert amc.compile_options["metrics"] == ['mse']
+        # check setting by dict
+        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
+        # check setting by attr
+        assert isinstance(amc.optimizer, keras.optimizers.SGD)
+        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
+        # check rest (all None as not set)
+        assert amc.compile_options["loss_weights"] is None
+        assert amc.compile_options["sample_weight_mode"] is None
+        assert amc.compile_options["target_tensors"] is None
+        assert amc.compile_options["weighted_metrics"] is None
+
+    def test_compile_options_property_type_error(self, amc):
+        with pytest.raises(TypeError) as einfo:
+            amc.compile_options = 'hello world'
+        assert "`compile_options' must be `dict' or `None', but is <class 'str'>." in str(einfo.value)
+
+    def test_compile_options_setter_as_mix_attr_dict_invalid_duplicates_other_optimizer(self, amc):
+        amc.optimizer = keras.optimizers.SGD()
+        with pytest.raises(ValueError) as einfo:
+            amc.compile_options = {"optimizer": keras.optimizers.Adam()}
+        assert "Got different values or arguments for same argument: self.optimizer=<class" \
+               " 'keras.optimizers.SGD'> and 'optimizer': <class 'keras.optimizers.Adam'>" in str(einfo.value)
+
+    def test_compile_options_setter_as_mix_attr_dict_invalid_duplicates_same_optimizer_other_args(self, amc):
+        amc.optimizer = keras.optimizers.SGD(lr=0.1)
+        with pytest.raises(ValueError) as einfo:
+            amc.compile_options = {"optimizer": keras.optimizers.SGD(lr=0.001)}
+        assert "Got different values or arguments for same argument: self.optimizer=<class" \
+               " 'keras.optimizers.SGD'> and 'optimizer': <class 'keras.optimizers.SGD'>" in str(einfo.value)
+
+    def test_compile_options_setter_as_dict_invalid_keys(self, amc):
+        with pytest.raises(ValueError) as einfo:
+            amc.compile_options = {"optimizer": keras.optimizers.SGD(), "InvalidKeyword": [1, 2, 3]}
+        assert "Got invalid key for compile_options. dict_keys(['optimizer', 'InvalidKeyword'])" in str(einfo.value)
+
+    def test_compare_keras_optimizers_equal(self, amc):
+        assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(), keras.optimizers.SGD()) is True
+
+    def test_compare_keras_optimizers_no_optimizer(self, amc):
+        assert amc._AbstractModelClass__compare_keras_optimizers('NoOptimizer', keras.optimizers.SGD()) is False
+
+    def test_compare_keras_optimizers_other_parameters_run_sess(self, amc):
+        assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(lr=0.1),
+                                                                 keras.optimizers.SGD(lr=0.01)) is False
+
+    def test_compare_keras_optimizers_other_parameters_none_sess(self, amc):
+        assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(decay=1),
+                                                                 keras.optimizers.SGD(decay=0.01)) is False
+
+    def test_getattr(self, amc):
+        amc.model = keras.Model()
+        assert hasattr(amc, "compile") is True
+        assert hasattr(amc.model, "compile") is True
+        assert amc.compile == amc.model.compile
+
+    def test_get_settings(self, amc, amsc):
+        assert amc.get_settings() == {"model_name": "AbstractModelClass", "_input_shape": (14, 1, 2),
+                                      "_output_shape": 3}
+        assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass",
+                                       "_input_shape": (12, 1, 2), "_output_shape": 3}
+
+    def test_custom_objects(self, amc):
+        amc.custom_objects = {"Test": 123}
+        assert amc.custom_objects == {"Test": 123}
+
+    def test_set_custom_objects(self, amc):
+        amc.set_custom_objects(Test=22, minor_param="minor")
+        assert amc.custom_objects == {"Test": 22, "minor_param": "minor"}
+        amc.set_custom_objects(Test=2, minor_param1="minor1")
+        assert amc.custom_objects == {"Test": 2, "minor_param1": "minor1"}
+        paddings = Paddings()
+        amc.set_custom_objects(Test=1, Padding2D=paddings)
+        assert amc.custom_objects == {"Test": 1, "Padding2D": paddings, "pad1": 34, "another_pad": True}
diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py
index 28218eb60e23d6e5b0e361fc2617398aade799cc..cbff4cec6c5b002c3166954880e1008e7f4d7ae3 100644
--- a/test/test_model_modules/test_model_class.py
+++ b/test/test_model_modules/test_model_class.py
@@ -1,205 +1,9 @@
 import keras
 import pytest
 
-from mlair.model_modules.model_class import AbstractModelClass
 from mlair.model_modules.model_class import MyPaperModel
 
 
-class Paddings:
-    allowed_paddings = {"pad1": 34, "another_pad": True}
-
-
-class AbstractModelSubClass(AbstractModelClass):
-
-    def __init__(self):
-        super().__init__(input_shape=(12, 1, 2), output_shape=3)
-        self.test_attr = "testAttr"
-
-
-class TestAbstractModelClass:
-
-    @pytest.fixture
-    def amc(self):
-        return AbstractModelClass(input_shape=(14, 1, 2), output_shape=(3,))
-
-    @pytest.fixture
-    def amsc(self):
-        return AbstractModelSubClass()
-
-    def test_init(self, amc):
-        assert amc.model is None
-        # assert amc.loss is None
-        assert amc.model_name == "AbstractModelClass"
-        assert amc.custom_objects == {}
-        assert amc._input_shape == (14, 1, 2)
-        assert amc._output_shape == 3
-
-    def test_model_property(self, amc):
-        amc.model = keras.Model()
-        assert isinstance(amc.model, keras.Model) is True
-
-    # def test_loss_property(self, amc):
-    #     amc.loss = keras.losses.mean_absolute_error
-    #     assert amc.loss == keras.losses.mean_absolute_error
-
-    def test_compile_options_setter_all_empty(self, amc):
-        amc.compile_options = None
-        assert amc.compile_options == {'optimizer': None,
-                                       'loss': None,
-                                       'metrics': None,
-                                       'loss_weights': None,
-                                       'sample_weight_mode': None,
-                                       'weighted_metrics': None,
-                                       'target_tensors': None
-                                       }
-
-    def test_compile_options_setter_as_dict(self, amc):
-        amc.compile_options = {"optimizer": keras.optimizers.SGD(),
-                               "loss": keras.losses.mean_absolute_error,
-                               "metrics": ["mse", "mae"]}
-        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
-        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
-        assert amc.compile_options["metrics"] == ["mse", "mae"]
-        assert amc.compile_options["loss_weights"] is None
-        assert amc.compile_options["sample_weight_mode"] is None
-        assert amc.compile_options["target_tensors"] is None
-        assert amc.compile_options["weighted_metrics"] is None
-
-    def test_compile_options_setter_as_attr(self, amc):
-        amc.optimizer = keras.optimizers.SGD()
-        amc.loss = keras.losses.mean_absolute_error
-        amc.compile_options = None  # This line has to be called!
-        # optimizer check
-        assert isinstance(amc.optimizer, keras.optimizers.SGD)
-        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
-        # loss check
-        assert amc.loss == keras.losses.mean_absolute_error
-        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
-        # check rest (all None as not set)
-        assert amc.compile_options["metrics"] is None
-        assert amc.compile_options["loss_weights"] is None
-        assert amc.compile_options["sample_weight_mode"] is None
-        assert amc.compile_options["target_tensors"] is None
-        assert amc.compile_options["weighted_metrics"] is None
-
-    def test_compile_options_setter_as_mix_attr_dict_no_duplicates(self, amc):
-        amc.optimizer = keras.optimizers.SGD()
-        amc.compile_options = {"loss": keras.losses.mean_absolute_error,
-                               "loss_weights": [0.2, 0.8]}
-        # check setting by attribute
-        assert isinstance(amc.optimizer, keras.optimizers.SGD)
-        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
-        # check setting by dict
-        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
-        assert amc.compile_options["loss_weights"] == [0.2, 0.8]
-        # check rest (all None as not set)
-        assert amc.compile_options["metrics"] is None
-        assert amc.compile_options["sample_weight_mode"] is None
-        assert amc.compile_options["target_tensors"] is None
-        assert amc.compile_options["weighted_metrics"] is None
-
-    def test_compile_options_setter_as_mix_attr_dict_valid_duplicates_optimizer(self, amc):
-        amc.optimizer = keras.optimizers.SGD()
-        amc.metrics = ['mse']
-        amc.compile_options = {"optimizer": keras.optimizers.SGD(),
-                               "loss": keras.losses.mean_absolute_error}
-        # check duplicate (attr and dic)
-        assert isinstance(amc.optimizer, keras.optimizers.SGD)
-        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
-        # check setting by dict
-        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
-        # check setting by attr
-        assert amc.metrics == ['mse']
-        assert amc.compile_options["metrics"] == ['mse']
-        # check rest (all None as not set)
-        assert amc.compile_options["loss_weights"] is None
-        assert amc.compile_options["sample_weight_mode"] is None
-        assert amc.compile_options["target_tensors"] is None
-        assert amc.compile_options["weighted_metrics"] is None
-
-    def test_compile_options_setter_as_mix_attr_dict_valid_duplicates_none_optimizer(self, amc):
-        amc.optimizer = keras.optimizers.SGD()
-        amc.metrics = ['mse']
-        amc.compile_options = {"metrics": ['mse'],
-                               "loss": keras.losses.mean_absolute_error}
-        # check duplicate (attr and dic)
-        assert amc.metrics == ['mse']
-        assert amc.compile_options["metrics"] == ['mse']
-        # check setting by dict
-        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
-        # check setting by attr
-        assert isinstance(amc.optimizer, keras.optimizers.SGD)
-        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
-        # check rest (all None as not set)
-        assert amc.compile_options["loss_weights"] is None
-        assert amc.compile_options["sample_weight_mode"] is None
-        assert amc.compile_options["target_tensors"] is None
-        assert amc.compile_options["weighted_metrics"] is None
-
-    def test_compile_options_property_type_error(self, amc):
-        with pytest.raises(TypeError) as einfo:
-            amc.compile_options = 'hello world'
-        assert "`compile_options' must be `dict' or `None', but is <class 'str'>." in str(einfo.value)
-
-    def test_compile_options_setter_as_mix_attr_dict_invalid_duplicates_other_optimizer(self, amc):
-        amc.optimizer = keras.optimizers.SGD()
-        with pytest.raises(ValueError) as einfo:
-            amc.compile_options = {"optimizer": keras.optimizers.Adam()}
-        assert "Got different values or arguments for same argument: self.optimizer=<class" \
-               " 'keras.optimizers.SGD'> and 'optimizer': <class 'keras.optimizers.Adam'>" in str(einfo.value)
-
-    def test_compile_options_setter_as_mix_attr_dict_invalid_duplicates_same_optimizer_other_args(self, amc):
-        amc.optimizer = keras.optimizers.SGD(lr=0.1)
-        with pytest.raises(ValueError) as einfo:
-            amc.compile_options = {"optimizer": keras.optimizers.SGD(lr=0.001)}
-        assert "Got different values or arguments for same argument: self.optimizer=<class" \
-               " 'keras.optimizers.SGD'> and 'optimizer': <class 'keras.optimizers.SGD'>" in str(einfo.value)
-
-    def test_compile_options_setter_as_dict_invalid_keys(self, amc):
-        with pytest.raises(ValueError) as einfo:
-            amc.compile_options = {"optimizer": keras.optimizers.SGD(), "InvalidKeyword": [1, 2, 3]}
-        assert "Got invalid key for compile_options. dict_keys(['optimizer', 'InvalidKeyword'])" in str(einfo.value)
-
-    def test_compare_keras_optimizers_equal(self, amc):
-        assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(), keras.optimizers.SGD()) is True
-
-    def test_compare_keras_optimizers_no_optimizer(self, amc):
-        assert amc._AbstractModelClass__compare_keras_optimizers('NoOptimizer', keras.optimizers.SGD()) is False
-
-    def test_compare_keras_optimizers_other_parameters_run_sess(self, amc):
-        assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(lr=0.1),
-                                                                 keras.optimizers.SGD(lr=0.01)) is False
-
-    def test_compare_keras_optimizers_other_parameters_none_sess(self, amc):
-        assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(decay=1),
-                                                                 keras.optimizers.SGD(decay=0.01)) is False
-
-    def test_getattr(self, amc):
-        amc.model = keras.Model()
-        assert hasattr(amc, "compile") is True
-        assert hasattr(amc.model, "compile") is True
-        assert amc.compile == amc.model.compile
-
-    def test_get_settings(self, amc, amsc):
-        assert amc.get_settings() == {"model_name": "AbstractModelClass", "_input_shape": (14, 1, 2),
-                                      "_output_shape": 3}
-        assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass",
-                                       "_input_shape": (12, 1, 2), "_output_shape": 3}
-
-    def test_custom_objects(self, amc):
-        amc.custom_objects = {"Test": 123}
-        assert amc.custom_objects == {"Test": 123}
-
-    def test_set_custom_objects(self, amc):
-        amc.set_custom_objects(Test=22, minor_param="minor")
-        assert amc.custom_objects == {"Test": 22, "minor_param": "minor"}
-        amc.set_custom_objects(Test=2, minor_param1="minor1")
-        assert amc.custom_objects == {"Test": 2, "minor_param1": "minor1"}
-        paddings = Paddings()
-        amc.set_custom_objects(Test=1, Padding2D=paddings)
-        assert amc.custom_objects == {"Test": 1, "Padding2D": paddings, "pad1": 34, "another_pad": True}
-
-
 class TestMyPaperModel:
 
     @pytest.fixture
diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py
index 382105344dfb9fffb37215f2706dda1f2ebd90ea..8a7572148869537b505b2bd8e7f16cfdf7af1cdd 100644
--- a/test/test_run_modules/test_model_setup.py
+++ b/test/test_run_modules/test_model_setup.py
@@ -8,7 +8,8 @@ from mlair.data_handler import KerasIterator
 from mlair.data_handler import DataCollection
 from mlair.helpers.datastore import EmptyScope
 from mlair.model_modules.keras_extensions import CallbackHandler
-from mlair.model_modules.model_class import AbstractModelClass, MyLittleModel
+from mlair.model_modules.fully_connected_networks import FCN_64_32_16
+from mlair.model_modules import AbstractModelClass
 from mlair.run_modules.model_setup import ModelSetup
 from mlair.run_modules.run_environment import RunEnvironment
 
@@ -22,7 +23,7 @@ class TestModelSetup:
         obj.scope = "general.model"
         obj.model = None
         obj.callbacks_name = "placeholder_%s_str.pickle"
-        obj.data_store.set("model_class", MyLittleModel)
+        obj.data_store.set("model_class", FCN_64_32_16)
         obj.data_store.set("lr_decay", "dummy_str", "general.model")
         obj.data_store.set("hist", "dummy_str", "general.model")
         obj.data_store.set("epochs", 2)
@@ -102,8 +103,7 @@ class TestModelSetup:
         assert setup_with_gen.model is None
         setup_with_gen.build_model()
         assert isinstance(setup_with_gen.model, AbstractModelClass)
-        expected = {"lr_decay", "model_name", "dropout_rate", "regularizer", "initial_lr", "optimizer", "activation",
-                    "input_shape", "output_shape"}
+        expected = {"lr_decay", "model_name", "optimizer", "activation", "input_shape", "output_shape"}
         assert expected <= self.current_scope_as_set(setup_with_gen)
 
     def test_set_shapes(self, setup_with_gen_tiny):