diff --git a/src/helpers.py b/src/helpers.py
index be73614319b39dc36043437c64379342a96ce00e..d4180336ec63f4f5477d3f2a149b5cb146be5597 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -11,6 +11,8 @@ import math
 import os
 import socket
 import time
+import types
 
 import keras.backend as K
 import xarray as xr
@@ -53,6 +55,9 @@ class TimeTrackingWrapper:
         with TimeTracking(name=self.__wrapped__.__name__):
             return self.__wrapped__(*args, **kwargs)
 
+    def __get__(self, instance, cls):
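+        """Bind the wrapped function as a method of the accessing instance (descriptor protocol).
+
+        Without this, decorating an instance method with TimeTrackingWrapper would lose the binding
+        to `self'. A minimal sketch:
+
+        .. code-block:: python
+
+            class A:
+                @TimeTrackingWrapper
+                def run(self):
+                    pass
+
+            A().run()  # `self' is passed through via __get__
+        """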
+        return types.MethodType(self, instance)
+
 
 class TimeTracking(object):
     """
diff --git a/src/model_modules/advanced_paddings.py b/src/model_modules/advanced_paddings.py
index d9e55c78fb6c78bbe219c820078c46a235627897..ea16e5b8a7c6a01456e286a2afaab4d5a88c96cc 100644
--- a/src/model_modules/advanced_paddings.py
+++ b/src/model_modules/advanced_paddings.py
@@ -254,10 +254,24 @@ class SymmetricPadding2D(_ZeroPadding):
 
 
 class Padding2D:
-    '''
-    This class combines the implemented padding methods. You can call this method by defining a specific padding type.
-    The __call__ method will return the corresponding Padding layer.
-    '''
+    """
+    Combine all implemented padding methods.
+
+    Instantiate this class with a specific padding type; calling the instance (via __call__) then returns the
+    corresponding padding layer.
+
+    .. code-block:: python
+
+        input_x = ...  # your input data
+        kernel_size = (5, 1)
+        padding_size = PadUtils.get_padding_for_same(kernel_size)
+
+        tower = layers.Conv2D(...)(input_x)
+        tower = layers.Activation(...)(tower)
+        tower = Padding2D('ZeroPad2D')(padding=padding_size, name='Custom_Pad')(tower)
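+
+        # alternatively, pass the padding class itself instead of a string
+        # (sketch; SymmetricPadding2D is defined in this module):
+        tower = Padding2D(SymmetricPadding2D)(padding=padding_size, name='Custom_SymPad')(tower)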
+
+    The padding type can either be set by a string or directly by passing a valid padding class.
+    """
 
     allowed_paddings = {
         **dict.fromkeys(("RefPad2D", "ReflectionPadding2D"), ReflectionPadding2D),
diff --git a/src/model_modules/flatten.py b/src/model_modules/flatten.py
index bbe92472ebb48e7486dede099dc098a161f51695..dd1e8e21eeb96f75372add0208b03dc06f5dc25c 100644
--- a/src/model_modules/flatten.py
+++ b/src/model_modules/flatten.py
@@ -1,33 +1,102 @@
 __author__ = "Felix Kleinert, Lukas Leufen"
 __date__ = '2019-12-02'
 
-from typing import Callable
+from typing import Union, Callable
 
 import keras
 
 
-def flatten_tail(input_X: keras.layers, name: str, bound_weight: bool = False, dropout_rate: float = 0.0,
-                 window_lead_time: int = 4, activation: Callable = keras.activations.relu,
-                 reduction_filter: int = 64, first_dense: int = 64):
+def get_activation(input_to_activate: keras.layers, activation: Union[Callable, str], **kwargs):
+    """
+    Apply activation on a given input layer.
 
-    X_in = keras.layers.Conv2D(reduction_filter, (1, 1), padding='same', name='{}_Conv_1x1'.format(name))(input_X)
+    This helper function is able to handle advanced keras activations as well as strings for standard activations.
 
-    X_in = activation(name='{}_conv_act'.format(name))(X_in)
+    :param input_to_activate: keras layer to apply activation on
+    :param activation: activation to apply on `input_to_activate'. Can be a standard keras string or an activation layer
+    :param kwargs: keyword arguments used inside activation layer
 
-    X_in = keras.layers.Flatten(name='{}'.format(name))(X_in)
+    :return: output of `activation' applied on `input_to_activate'
 
-    X_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_1'.format(name))(X_in)
-    X_in = keras.layers.Dense(first_dense, kernel_regularizer=keras.regularizers.l2(0.01),
-                              name='{}_Dense_1'.format(name))(X_in)
+    .. code-block:: python
+
+        input_x = ... # your input data
+        x_in = keras.layers.Dense(10)(input_x)  # any layer without a built-in activation
+
+        # get activation via string
+        x_act_string = get_activation(x_in, 'relu')
+        # or get activation via layer callable
+        x_act_layer = get_activation(x_in, keras.layers.advanced_activations.ELU)
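+        # keyword arguments are passed through to the activation layer,
+        # e.g. (sketch; `alpha' is the ELU slope parameter):
+        x_act_kwargs = get_activation(x_in, keras.layers.advanced_activations.ELU, alpha=0.5, name='act_elu')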
+
+    """
+    if isinstance(activation, str):
+        # standard activation given as string: wrap it in an Activation layer and
+        # derive a unique layer name like '<name>_relu'
+        name = kwargs.pop('name', None)
+        kwargs['name'] = f'{name}_{activation}'
+        act = keras.layers.Activation(activation, **kwargs)(input_to_activate)
+    else:
+        act = activation(**kwargs)(input_to_activate)
+    return act
+
+
+def flatten_tail(input_x: keras.layers, inner_neurons: int, activation: Union[Callable, str],
+                 output_neurons: int, output_activation: Union[Callable, str],
+                 reduction_filter: int = None,
+                 name: str = None,
+                 bound_weight: bool = False,
+                 dropout_rate: float = None,
+                 kernel_regularizer: keras.regularizers = None
+                 ):
+    """
+    Flatten output of convolutional layers.
+
+    :param input_x: Multidimensional keras layer (ConvLayer)
+    :param inner_neurons: Number of neurons in the inner dense layer
+    :param activation: activation to apply after conv and dense layers
+    :param output_neurons: Number of neurons in the last layer (must fit the shape of labels)
+    :param output_activation: final activation function
+    :param reduction_filter: number of filters used for information compression on `input_x' before flatten()
+    :param name: Name of the flatten tail.
+    :param bound_weight: Use `tanh' as inner activation if set to True, otherwise `activation'
+    :param dropout_rate: Dropout rate to be applied between trainable layers
+    :param kernel_regularizer: regularizer to apply on conv and dense layers
+
+    :return: flatten branch with size n=output_neurons
+
+    .. code-block:: python
+
+        input_x = ... # your input data
+        conv_out = Conv2D(*args)(input_x) # your convolution stack
+        out = flatten_tail(conv_out, inner_neurons=64, activation=keras.layers.advanced_activations.ELU,
+                           output_neurons=4,
+                           output_activation='linear', reduction_filter=64,
+                           name='Main', bound_weight=False, dropout_rate=.3,
+                           kernel_regularizer=keras.regularizers.l2()
+                           )
+        model = keras.Model(inputs=input_x, outputs=[out])
+
+    """
+    # compression layer
+    if reduction_filter is None:
+        x_in = input_x
+    else:
+        x_in = keras.layers.Conv2D(reduction_filter, (1, 1), name=f'{name}_Conv_1x1',
+                                   kernel_regularizer=kernel_regularizer)(input_x)
+        x_in = get_activation(x_in, activation, name=f'{name}_conv_act')
+
+    x_in = keras.layers.Flatten(name=f'{name}')(x_in)
+
+    if dropout_rate is not None:
+        x_in = keras.layers.Dropout(dropout_rate, name=f'{name}_Dropout_1')(x_in)
+    x_in = keras.layers.Dense(inner_neurons, kernel_regularizer=kernel_regularizer,
+                              name=f'{name}_inner_Dense')(x_in)
     if bound_weight:
-        X_in = keras.layers.Activation('tanh')(X_in)
+        x_in = keras.layers.Activation('tanh')(x_in)
     else:
-        try:
-            X_in = activation(name='{}_act'.format(name))(X_in)
-        except:
-            X_in = activation()(X_in)
-
-    X_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_2'.format(name))(X_in)
-    out = keras.layers.Dense(window_lead_time, activation='linear', kernel_regularizer=keras.regularizers.l2(0.01),
-                             name='{}_Dense_2'.format(name))(X_in)
+        x_in = get_activation(x_in, activation, name=f'{name}_act')
+
+    if dropout_rate is not None:
+        x_in = keras.layers.Dropout(dropout_rate, name=f'{name}_Dropout_2')(x_in)
+    out = keras.layers.Dense(output_neurons, kernel_regularizer=kernel_regularizer,
+                             name=f'{name}_out_Dense')(x_in)
+    out = get_activation(out, output_activation, name=f'{name}_final_act')
     return out
diff --git a/src/model_modules/inception_model.py b/src/model_modules/inception_model.py
index 6467b3245ad097af6ef17e596f85264eef383d7a..15739556d7d28d9e7e6ecc454615d82fb81a2754 100644
--- a/src/model_modules/inception_model.py
+++ b/src/model_modules/inception_model.py
@@ -75,12 +75,9 @@ class InceptionModelBase:
                                   name=f'Block_{self.number_of_blocks}{self.block_part_name()}_1x1')(input_x)
             tower = self.act(tower, activation, **act_settings)
 
-            # tower = self.padding_layer(padding)(padding=padding_size,
-            #                                     name=f'Block_{self.number_of_blocks}{self.block_part_name()}_Pad'
-            #                                     )(tower)
             tower = Padding2D(padding)(padding=padding_size,
-                                                name=f'Block_{self.number_of_blocks}{self.block_part_name()}_Pad'
-                                                )(tower)
+                                       name=f'Block_{self.number_of_blocks}{self.block_part_name()}_Pad'
+                                       )(tower)
 
             tower = layers.Conv2D(tower_filter,
                                   tower_kernel,
@@ -111,29 +108,6 @@ class InceptionModelBase:
         else:
             return act_name.__name__
 
-    # @staticmethod
-    # def padding_layer(padding):
-    #     allowed_paddings = {
-    #         'RefPad2D': ReflectionPadding2D, 'ReflectionPadding2D': ReflectionPadding2D,
-    #         'SymPad2D': SymmetricPadding2D, 'SymmetricPadding2D': SymmetricPadding2D,
-    #         'ZeroPad2D': keras.layers.ZeroPadding2D, 'ZeroPadding2D': keras.layers.ZeroPadding2D
-    #     }
-    #     if isinstance(padding, str):
-    #         try:
-    #             pad2d = allowed_paddings[padding]
-    #         except KeyError as einfo:
-    #             raise NotImplementedError(
-    #                 f"`{einfo}' is not implemented as padding. "
-    #                 "Use one of those: i) `RefPad2D', ii) `SymPad2D', iii) `ZeroPad2D'")
-    #     else:
-    #         if padding in allowed_paddings.values():
-    #             pad2d = padding
-    #         else:
-    #             raise TypeError(f"`{padding.__name__}' is not a valid padding layer type. "
-    #                             "Use one of those: "
-    #                             "i) ReflectionPadding2D, ii) SymmetricPadding2D, iii) ZeroPadding2D")
-    #     return pad2d
-
     def create_pool_tower(self, input_x, pool_kernel, tower_filter, activation='relu', max_pooling=True, **kwargs):
         """
         This function creates a "MaxPooling tower block"
@@ -159,7 +133,6 @@ class InceptionModelBase:
             block_type = "AvgPool"
             pooling = layers.AveragePooling2D
 
-        # tower = self.padding_layer(padding)(padding=padding_size, name=block_name+'Pad')(input_x)
         tower = Padding2D(padding)(padding=padding_size, name=block_name+'Pad')(input_x)
         tower = pooling(pool_kernel, strides=(1, 1), padding='valid', name=block_name+block_type)(tower)
 
@@ -215,35 +188,6 @@ class InceptionModelBase:
         return block
 
 
-# if __name__ == '__main__':
-#     from keras.models import Model
-#     from keras.layers import Conv2D, Flatten, Dense, Input
-#     import numpy as np
-#
-#
-#     kernel_1 = (3, 3)
-#     kernel_2 = (5, 5)
-#     x = np.array(range(2000)).reshape(-1, 10, 10, 1)
-#     y = x.mean(axis=(1, 2))
-#
-#     x_input = Input(shape=x.shape[1:])
-#     pad1 = PadUtils.get_padding_for_same(kernel_size=kernel_1)
-#     x_out = InceptionModelBase.padding_layer('RefPad2D')(padding=pad1, name="RefPAD1")(x_input)
-#     # x_out = ReflectionPadding2D(padding=pad1, name="RefPAD")(x_input)
-#     x_out = Conv2D(5, kernel_size=kernel_1, activation='relu')(x_out)
-#
-#     pad2 = PadUtils.get_padding_for_same(kernel_size=kernel_2)
-#     x_out = InceptionModelBase.padding_layer(SymmetricPadding2D)(padding=pad2, name="SymPAD1")(x_out)
-#     # x_out = SymmetricPadding2D(padding=pad2, name="SymPAD")(x_out)
-#     x_out = Conv2D(2, kernel_size=kernel_2, activation='relu')(x_out)
-#     x_out = Flatten()(x_out)
-#     x_out = Dense(1, activation='linear')(x_out)
-#
-#     model = Model(inputs=x_input, outputs=x_out)
-#     model.compile('adam', loss='mse')
-#     model.summary()
-#     # model.fit(x, y, epochs=10)
-
 if __name__ == '__main__':
     print(__name__)
     from keras.datasets import cifar10
diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index 0064c795e9bba162fafe3e9d5f60a17b95a4d57f..290a527b0ae2ccc9e17ddf5ed49098a4a55a173b 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -427,8 +427,14 @@ class MyTowerModel(AbstractModelClass):
                                                batch_normalisation=True)
         #############################################
 
-        out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=self.dropout_rate,
-                                reduction_filter=64, first_dense=64, window_lead_time=self.window_lead_time)
+        out_main = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time,
+                                output_activation='linear', reduction_filter=64,
+                                name='Main', bound_weight=True, dropout_rate=self.dropout_rate,
+                                kernel_regularizer=self.regularizer
+                                )
 
         self.model = keras.Model(inputs=X_input, outputs=[out_main])
 
@@ -550,8 +556,13 @@ class MyPaperModel(AbstractModelClass):
                                                regularizer=self.regularizer,
                                                batch_normalisation=True,
                                                padding=self.padding)
-        out_minor1 = flatten_tail(X_in, 'minor_1', False, self.dropout_rate, self.window_lead_time,
-                                  self.activation, 32, 64)
+        out_minor1 = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time,
+                                  output_activation='linear', reduction_filter=32,
+                                  name='minor_1', bound_weight=False, dropout_rate=self.dropout_rate,
+                                  kernel_regularizer=self.regularizer
+                                  )
 
         X_in = keras.layers.Dropout(self.dropout_rate)(X_in)
 
@@ -564,8 +575,11 @@ class MyPaperModel(AbstractModelClass):
         #                                        batch_normalisation=True)
         #############################################
 
-        out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=False, dropout_rate=self.dropout_rate,
-                                reduction_filter=64 * 2, first_dense=64 * 2, window_lead_time=self.window_lead_time)
+        out_main = flatten_tail(X_in, inner_neurons=64 * 2, activation=activation, output_neurons=self.window_lead_time,
+                                output_activation='linear', reduction_filter=64 * 2,
+                                name='Main', bound_weight=False, dropout_rate=self.dropout_rate,
+                                kernel_regularizer=self.regularizer
+                                )
 
         self.model = keras.Model(inputs=X_input, outputs=[out_minor1, out_main])
 
diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 14e3074a7d8f09bd597fb2fbf53a298d83ab6556..b61e832c80ac9ad83a5aa4a4b5310b17f6add098 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -189,26 +189,12 @@ class PlotStationMap(AbstractPlotClass):
 
 
 @TimeTrackingWrapper
-def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_window: int = 3, ref_name: str = 'obs',
-                               pred_name: str = 'CNN', season: str = "", forecast_path: str = None,
-                               plot_name_affix: str = "", units: str = "ppb"):
+class PlotConditionalQuantiles(AbstractPlotClass):
     """
-    This plot was originally taken from Murphy, Brown and Chen (1989):
-    https://journals.ametsoc.org/doi/pdf/10.1175/1520-0434%281989%29004%3C0485%3ADVOTF%3E2.0.CO%3B2
-
-    :param stations: stations to include in the plot (forecast data needs to be available already)
-    :param plot_folder: path to save the plot (default: current directory)
-    :param rolling_window: the rolling window mean will smooth the plot appearance (no smoothing in bin calculation,
-        this is only a cosmetic step, default: 3)
-    :param ref_name: name of the reference data series
-    :param pred_name: name of the investigated data series
-    :param season: season name to highlight if not empty
-    :param forecast_path: path to save the plot file
-    :param plot_name_affix: name to specify this plot (e.g. 'cali-ref', default: '')
-    :param units: units of the forecasted values (default: ppb)
+    This class creates conditional quantile plots as originally proposed by Murphy, Brown and Chen (1989),
+    with the sample-size histogram drawn on a log scale.
+
+    Link to paper: https://journals.ametsoc.org/doi/pdf/10.1175/1520-0434%281989%29004%3C0485%3ADVOTF%3E2.0.CO%3B2
     """
-    # time = TimeTracking()
-    logging.debug(f"started plot_conditional_quantiles()")
     # ignore warnings if nans appear in quantile grouping
     warnings.filterwarnings("ignore", message="All-NaN slice encountered")
     # ignore warnings if mean is calculated on nans
@@ -216,112 +202,238 @@ def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_w
     # ignore warnings for y tick = 0 on log scale (instead of 0.00001 or similar)
     warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.")
 
-    def load_data():
+    def __init__(self, stations: List, data_pred_path: str, plot_folder: str = ".", plot_per_seasons: bool = True,
+                 rolling_window: int = 3, model_name: str = "CNN", obs_name: str = "obs", **kwargs):
+        """
+        Set all attributes, load the forecast data, and create all plots on instantiation.
+
+        :param stations: all stations to plot
+        :param data_pred_path: path to dir which contains the forecasts as .nc files
+        :param plot_folder: path where the plots are stored
+        :param plot_per_seasons: if `True' create conditional quantile plots for the seasons (DJF, MAM, JJA, SON)
+            individually
+        :param rolling_window: smoothing of quantiles (3 is used by Murphy et al.)
+        :param model_name: name of the model prediction as stored in the netCDF file (for example "CNN")
+        :param obs_name: name of observation as stored in netCDF file (for example "obs")
+        :param kwargs: Some further arguments which are listed in self._opts
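+
+        A minimal usage sketch (all plots are created on instantiation; the paths are placeholders):
+
+        .. code-block:: python
+
+            PlotConditionalQuantiles(stations=['DEBW107', 'DEBY081'], data_pred_path='<path_to_forecasts>',
+                                     plot_folder='<path_to_plots>')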
+        """
+        super().__init__(plot_folder, "conditional_quantiles")
+
+        self._data_pred_path = data_pred_path
+        self._stations = stations
+        self._rolling_window = rolling_window
+        self._model_name = model_name
+        self._obs_name = obs_name
+
+        self._opts = {"q": kwargs.get("q", [.1, .25, .5, .75, .9]),
+                      "linetype": kwargs.get("linetype", [':', '-.', '--', '-.', ':']),
+                      "legend": kwargs.get("legend",
+                                           ['.10th and .90th quantile', '.25th and .75th quantile', '.50th quantile',
+                                            'reference 1:1']),
+                      "data_unit": kwargs.get("data_unit", "ppb"),
+                      }
+        if plot_per_seasons:
+            self.seasons = ['DJF', 'MAM', 'JJA', 'SON']
+        else:
+            self.seasons = []
+        self._data = self._load_data()
+        self._bins = self._get_bins_from_range_of_data()
+
+        self._plot()
+
+    def _load_data(self):
+        """
+        Load the forecast data of all stations.
+
+        :return: xarray DataArray containing the forecasts of all stations
+        """
         logging.debug("... load data")
         data_collector = []
-        for station in stations:
-            file = os.path.join(forecast_path, f"forecasts_{station}_test.nc")
+        for station in self._stations:
+            file = os.path.join(self._data_pred_path, f"forecasts_{station}_test.nc")
             data_tmp = xr.open_dataarray(file)
-            data_collector.append(data_tmp.loc[:, :, ['CNN', 'obs', 'OLS']].assign_coords(station=station))
-        return xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station')
+            data_collector.append(data_tmp.loc[:, :, [self._model_name, self._obs_name]].assign_coords(station=station))
+        res = xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station')
+        return res
 
-    def segment_data(data):
+    def _segment_data(self, data, x_model):
+        """
+        Segment the data of `x_model' into bins; the segmented data is used for the conditional quantile plots.
+
+        :param data: data to segment
+        :param x_model: name of the data series shown on the x axis
+        :return: data with `x_model' values replaced by their bin labels
+        """
         logging.debug("... segment data")
         # combine index and station to multi index
         data = data.stack(z=['index', 'station'])
         # replace multi index by simple position index (order is not relevant anymore)
         data.coords['z'] = range(len(data.coords['z']))
-        # segment data of pred_name into bins
-        data.loc[pred_name, ...] = data.loc[pred_name, ...].to_pandas().T.apply(pd.cut, bins=bins,
-                                                                                labels=bins[1:]).T.values
+        # segment data of x_model into bins
+        data.loc[x_model, ...] = data.loc[x_model, ...].to_pandas().T.apply(
+            pd.cut, bins=self._bins, labels=self._bins[1:]).T.values
         return data
 
-    def create_quantile_panel(data, q):
+    @staticmethod
+    def _labels(plot_type, data_unit="ppb"):
+        """
+        Helper method to correctly assign (x, y) labels to plots, depending on like-base or cali-ref factorization.
+
+        :param plot_type: name of the data series shown on the x axis (either the observation or the model name)
+        :param data_unit: unit of the data, used in the axis labels
+        :return: pair of axis labels (forecast and observation label; the order depends on `plot_type')
+        """
+        names = (f"forecast concentration (in {data_unit})", f"observed concentration (in {data_unit})")
+        if plot_type == "obs":
+            return names
+        else:
+            return names[::-1]
+
+    def _get_bins_from_range_of_data(self):
+        """
+        Get integer array of bins to use for the quantiles, spanning from 0 to the overall data maximum.
+
+        :return: array of bin edges
+        """
+        return np.arange(0, math.ceil(self._data.max().max()) + 1, 1).astype(int)
+
+    def _create_quantile_panel(self, data, x_model, y_model):
+        """
+        Calculate the quantiles of `y_model' conditioned on the bins of `x_model'.
+
+        :param data: segmented data
+        :param x_model: name of the data series shown on the x axis
+        :param y_model: name of the data series whose quantiles are calculated
+        :return: xarray DataArray of quantiles with dims (ahead, quantiles, categories)
+        """
         logging.debug("... create quantile panel")
         # create empty xarray with dims: time steps ahead, quantiles, bin index (numbers create in previous step)
-        quantile_panel = xr.DataArray(np.full([data.ahead.shape[0], len(q), bins[1:].shape[0]], np.nan),
-                                      coords=[data.ahead, q, bins[1:]], dims=['ahead', 'quantiles', 'categories'])
+        quantile_panel = xr.DataArray(
+            np.full([data.ahead.shape[0], len(self._opts["q"]), self._bins[1:].shape[0]], np.nan),
+            coords=[data.ahead, self._opts["q"], self._bins[1:]], dims=['ahead', 'quantiles', 'categories'])
         # ensure that the coordinates are in the right order
         quantile_panel = quantile_panel.transpose('ahead', 'quantiles', 'categories')
         # calculate for each bin of the pred_name data the quantiles of the ref_name data
-        for bin in bins[1:]:
-            mask = (data.loc[pred_name, ...] == bin)
-            quantile_panel.loc[..., bin] = data.loc[ref_name, ...].where(mask).quantile(q, dim=['z']).T
-
+        for bin in self._bins[1:]:
+            mask = (data.loc[x_model, ...] == bin)
+            quantile_panel.loc[..., bin] = data.loc[y_model, ...].where(mask).quantile(
+                self._opts["q"], dim=['z']).T
         return quantile_panel
 
-    def labels(plot_type, data_unit="ppb"):
-        names = (f"forecast concentration (in {data_unit})", f"observed concentration (in {data_unit})")
-        if plot_type == "obs":
-            return names
-        else:
-            return names[::-1]
+    @staticmethod
+    def add_affix(x):
+        """
+        Helper method to add additional information to the plot name, e.g. add_affix("DJF") returns "_DJF".
+
+        :param x: affix to add; an empty string is ignored
+        :return: affix with a leading underscore, or empty string
+        """
+        return f"_{x}" if len(x) > 0 else ""
+
+    def _prepare_plots(self, data, x_model, y_model):
+        """
+        Get segmented data and quantile panel.
+
+        :param data: data to prepare
+        :param x_model: name of the data series shown on the x axis
+        :param y_model: name of the data series shown on the y axis
+        :return: tuple of (segmented data, quantile panel)
+        """
+        segmented_data = self._segment_data(data, x_model)
+        quantile_panel = self._create_quantile_panel(segmented_data, x_model, y_model)
+        return segmented_data, quantile_panel
+
+    def _plot(self):
+        """
+        Main plot call
+
+        :return:
+        """
+        logging.info(f"start plotting {self.__class__.__name__}, scheduled number of plots: {(len(self.seasons) + 1) * 2}")
+
+        if len(self.seasons) > 0:
+            self._plot_seasons()
+        self._plot_all()
+
+    def _plot_seasons(self):
+        """
+        Seasonal plot call
 
-    xlabel, ylabel = labels(ref_name, units)
-
-    opts = {"q": [.1, .25, .5, .75, .9], "linetype": [':', '-.', '--', '-.', ':'],
-            "legend": ['.10th and .90th quantile', '.25th and .75th quantile', '.50th quantile', 'reference 1:1'],
-            "xlabel": xlabel, "ylabel": ylabel}
-
-    # set name and path of the plot
-    base_name = "conditional_quantiles"
-    def add_affix(x): return f"_{x}" if len(x) > 0 else ""
-    plot_name = f"{base_name}{add_affix(season)}{add_affix(plot_name_affix)}_plot.pdf"
-    plot_path = os.path.join(os.path.abspath(plot_folder), plot_name)
-
-    # check forecast path
-    if forecast_path is None:
-        raise ValueError("Forecast path is not given but required.")
-
-    # load data and set data bins
-    orig_data = load_data()
-    bins = np.arange(0, math.ceil(orig_data.max().max()) + 1, 1).astype(int)
-    segmented_data = segment_data(orig_data)
-    quantile_panel = create_quantile_panel(segmented_data, q=opts["q"])
-
-    # init pdf output
-    pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
-    logging.debug(f"... plot path is {plot_path}")
-
-    # create plot for each time step ahead
-    y2_max = 0
-    for iteration, d in enumerate(segmented_data.ahead):
-        logging.debug(f"... plotting {d.values} time step(s) ahead")
-        # plot smoothed lines with rolling mean
-        smooth_data = quantile_panel.loc[d, ...].rolling(categories=rolling_window, center=True).mean().to_pandas().T
-        ax = smooth_data.plot(style=opts["linetype"], color='black', legend=False)
-        ax2 = ax.twinx()
-        # add reference line
-        ax.plot([0, bins.max()], [0, bins.max()], color='k', label='reference 1:1', linewidth=.8)
-        # add histogram of the segmented data (pred_name)
-        handles, labels = ax.get_legend_handles_labels()
-        segmented_data.loc[pred_name, d, :].to_pandas().hist(bins=bins, ax=ax2, color='k', alpha=.3, grid=False,
-                                                             rwidth=1)
-        # add legend
-        plt.legend(handles[:3] + [handles[-1]], opts["legend"], loc='upper left', fontsize='large')
-        # adjust limits and set labels
-        ax.set(xlim=(0, bins.max()), ylim=(0, bins.max()))
-        ax.set_xlabel(opts["xlabel"], fontsize='x-large')
-        ax.tick_params(axis='x', which='major', labelsize=15)
-        ax.set_ylabel(opts["ylabel"], fontsize='x-large')
-        ax.tick_params(axis='y', which='major', labelsize=15)
-        ax2.yaxis.label.set_color('gray')
-        ax2.tick_params(axis='y', colors='gray')
-        ax2.yaxis.labelpad = -15
-        ax2.set_yscale('log')
-        if iteration == 0:
-            y2_max = ax2.get_ylim()[1] + 100
-        ax2.set(ylim=(0, y2_max * 10 ** 8), yticks=np.logspace(0, 4, 5))
-        ax2.set_ylabel('              sample size', fontsize='x-large')
-        ax2.tick_params(axis='y', which='major', labelsize=15)
-        # set title and save current figure
-        title = f"{d.values} time step(s) ahead{f' ({season})' if len(season) > 0 else ''}"
-        plt.title(title)
-        pdf_pages.savefig()
-    # close all open figures / plots
-    pdf_pages.close()
-    plt.close('all')
-    #logging.info(f"plot_conditional_quantiles() finished after {time}")
+        :return:
+        """
+        for season in self.seasons:
+            self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._model_name,
+                            y_model=self._obs_name, plot_name_affix="cali-ref", season=season)
+            self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._obs_name,
+                            y_model=self._model_name, plot_name_affix="like-base", season=season)
+
+    def _plot_all(self):
+        """
+        Full plot call
+
+        :return:
+        """
+        self._plot_base(data=self._data, x_model=self._model_name, y_model=self._obs_name, plot_name_affix="cali-ref")
+        self._plot_base(data=self._data, x_model=self._obs_name, y_model=self._model_name, plot_name_affix="like-base")
+
+    @TimeTrackingWrapper
+    def _plot_base(self, data, x_model, y_model, plot_name_affix, season=""):
+        """
+        Base method to create conditional quantile plots; called from _plot_all and _plot_seasons.
+
+        :param data: data which is used to create the conditional quantile plot
+        :param x_model: name of model on x axis (can also be obs)
+        :param y_model: name of model on y axis (can also be obs)
+        :param plot_name_affix: should be `cali-ref' or `like-base'
+        :param season: season name used in the plot title and file name (empty string to disable)
+        :return:
+        """
+
+        segmented_data, quantile_panel = self._prepare_plots(data, x_model, y_model)
+        ylabel, xlabel = self._labels(x_model, self._opts["data_unit"])
+        plot_name = f"{self.plot_name}{self.add_affix(season)}{self.add_affix(plot_name_affix)}_plot.pdf"
+        #f"{base_name}{add_affix(season)}{add_affix(plot_name_affix)}_plot.pdf"
+        plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
+        pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
+        logging.debug(f"... plot path is {plot_path}")
+
+        # create plot for each time step ahead
+        y2_max = 0
+        for iteration, d in enumerate(segmented_data.ahead):
+            logging.debug(f"... plotting {d.values} time step(s) ahead")
+            # plot smoothed lines with rolling mean
+            smooth_data = quantile_panel.loc[d, ...].rolling(categories=self._rolling_window,
+                                                             center=True).mean().to_pandas().T
+            ax = smooth_data.plot(style=self._opts["linetype"], color='black', legend=False)
+            ax2 = ax.twinx()
+            # add reference line
+            ax.plot([0, self._bins.max()], [0, self._bins.max()], color='k', label='reference 1:1', linewidth=.8)
+            # add histogram of the segmented data (pred_name)
+            handles, labels = ax.get_legend_handles_labels()
+            segmented_data.loc[x_model, d, :].to_pandas().hist(
+                bins=self._bins, ax=ax2, color='k', alpha=.3, grid=False, rwidth=1)
+            # add legend
+            plt.legend(handles[:3] + [handles[-1]], self._opts["legend"], loc='upper left', fontsize='large')
+            # adjust limits and set labels
+            ax.set(xlim=(0, self._bins.max()), ylim=(0, self._bins.max()))
+            ax.set_xlabel(xlabel, fontsize='x-large')
+            ax.tick_params(axis='x', which='major', labelsize=15)
+            ax.set_ylabel(ylabel, fontsize='x-large')
+            ax.tick_params(axis='y', which='major', labelsize=15)
+            ax2.yaxis.label.set_color('gray')
+            ax2.tick_params(axis='y', colors='gray')
+            ax2.yaxis.labelpad = -15
+            ax2.set_yscale('log')
+            if iteration == 0:
+                y2_max = ax2.get_ylim()[1] + 100
+            ax2.set(ylim=(0, y2_max * 10 ** 8), yticks=np.logspace(0, 4, 5))
+            ax2.set_ylabel('              sample size', fontsize='x-large')
+            ax2.tick_params(axis='y', which='major', labelsize=15)
+            # set title and save current figure
+            title = f"{d.values} time step(s) ahead{f' ({season})' if len(season) > 0 else ''}"
+            plt.title(title)
+            pdf_pages.savefig()
+        # close all open figures / plots
+        pdf_pages.close()
+        plt.close('all')
 
 
 @TimeTrackingWrapper
@@ -697,7 +809,6 @@ class PlotAvailability(AbstractPlotClass):
                 plt_dict[summary_name].update({subset: t2})
         return plt_dict
 
-
     def _plot(self, plt_dict):
         # colors = {"train": "orange", "val": "blueishgreen", "test": "skyblue"}  # color names
         colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"}  # hex code
@@ -722,3 +833,12 @@ class PlotAvailability(AbstractPlotClass):
         handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items()]
         lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
         return lgd
+
+
+if __name__ == "__main__":
+    stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
+    path = "../../testrun_network/forecasts"
+    plt_path = "../../"
+
+    con_quan_cls = PlotConditionalQuantiles(stations, path, plt_path)
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 150399cb2e4997a6b9adfb30dfa3ff89de73d4ac..09b9f143fc0442ee34ef5735366145be86b5fa07 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -21,7 +21,7 @@ DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'max
                         'pblheight': 'maximum'}
 DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
 DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
-                     "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "plot_conditional_quantiles",
+                     "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
                      "PlotAvailability"]
 
 
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 8a962888ec0b789a14a24b20c97148e7a8315b30..dfeaf06533e8023cf872763e0f34d98c5dd27a01 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -20,8 +20,8 @@ from src.helpers import TimeTracking
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
 from src.model_modules.model_class import AbstractModelClass
 from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
-    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability
-from src.plotting.postprocessing_plotting import plot_conditional_quantiles
+    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles
 from src.run_modules.run_environment import RunEnvironment
 
 from typing import Dict
@@ -195,11 +195,8 @@ class PostProcessing(RunEnvironment):
 
         if self.bootstrap_skill_scores is not None and "PlotBootstrapSkillScore" in plot_list:
             PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN")
-        if "plot_conditional_quantiles" in plot_list:
-            plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="obs",
-                                       forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path)
-            plot_conditional_quantiles(self.test_data.stations, pred_name="obs", ref_name="CNN",
-                                       forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
+        if "PlotConditionalQuantiles" in plot_list:
+            PlotConditionalQuantiles(self.test_data.stations, data_pred_path=path, plot_folder=self.plot_path)
         if "PlotStationMap" in plot_list:
             PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
         if "PlotMonthlySummary" in plot_list:
diff --git a/test/test_model_modules/test_flatten_tail.py b/test/test_model_modules/test_flatten_tail.py
new file mode 100644
index 0000000000000000000000000000000000000000..0de138ec2323aea3409d5deadfb26c9741b89f50
--- /dev/null
+++ b/test/test_model_modules/test_flatten_tail.py
@@ -0,0 +1,119 @@
+import keras
+import pytest
+from src.model_modules.flatten import flatten_tail, get_activation
+
+
+class TestGetActivation:
+
+    @pytest.fixture()
+    def model_input(self):
+        input_x = keras.layers.Input(shape=(7, 1, 2))
+        return input_x
+
+    def test_string_act(self, model_input):
+        x_in = get_activation(model_input, activation='relu', name='String')
+        act = x_in._keras_history[0]
+        assert act.name == 'String_relu'
+
+    def test_string_act_unknown(self, model_input):
+        with pytest.raises(ValueError) as einfo:
+            get_activation(model_input, activation='invalid_activation', name='String')
+        assert 'Unknown activation function:invalid_activation' in str(einfo.value)
+
+    def test_layer_act(self, model_input):
+        x_in = get_activation(model_input, activation=keras.layers.advanced_activations.ELU, name='adv_layer')
+        act = x_in._keras_history[0]
+        assert act.name == 'adv_layer'
+
+    def test_layer_act_invalid(self, model_input):
+        with pytest.raises(TypeError) as einfo:
+            get_activation(model_input, activation=keras.layers.Conv2D, name='adv_layer')
+
+
+class TestFlattenTail:
+
+    @pytest.fixture()
+    def model_input(self):
+        input_x = keras.layers.Input(shape=(7, 1, 2))
+        return input_x
+
+    @staticmethod
+    def step_in(element, depth=1):
+        for _ in range(depth):
+            element = element.input._keras_history[0]
+        return element
+
+    def test_flatten_tail_no_bound_no_regul_no_drop(self, model_input):
+        tail = flatten_tail(input_x=model_input, inner_neurons=64, activation=keras.layers.advanced_activations.ELU,
+                            output_neurons=2, output_activation='linear',
+                            reduction_filter=None,
+                            name='Main_tail',
+                            bound_weight=False,
+                            dropout_rate=None,
+                            kernel_regularizer=None)
+        final_act = tail._keras_history[0]
+        assert final_act.name == 'Main_tail_final_act_linear'
+        final_dense = self.step_in(final_act)
+        assert final_dense.name == 'Main_tail_out_Dense'
+        assert final_dense.units == 2
+        assert final_dense.kernel_regularizer is None
+        inner_act = self.step_in(final_dense)
+        assert inner_act.name == 'Main_tail_act'
+        assert inner_act.__class__.__name__ == 'ELU'
+        inner_dense = self.step_in(inner_act)
+        assert inner_dense.name == 'Main_tail_inner_Dense'
+        assert inner_dense.units == 64
+        assert inner_dense.kernel_regularizer is None
+        flatten = self.step_in(inner_dense)
+        assert flatten.name == 'Main_tail'
+        input_layer = self.step_in(flatten)
+        assert input_layer.input_shape == (None, 7, 1, 2)
+
+    def test_flatten_tail_all_settings(self, model_input):
+        tail = flatten_tail(input_x=model_input, inner_neurons=64, activation=keras.layers.advanced_activations.ELU,
+                            output_neurons=3, output_activation='linear',
+                            reduction_filter=32,
+                            name='Main_tail_all',
+                            bound_weight=True,
+                            dropout_rate=.35,
+                            kernel_regularizer=keras.regularizers.l2())
+
+        final_act = tail._keras_history[0]
+        assert final_act.name == 'Main_tail_all_final_act_linear'
+
+        final_dense = self.step_in(final_act)
+        assert final_dense.name == 'Main_tail_all_out_Dense'
+        assert final_dense.units == 3
+        assert isinstance(final_dense.kernel_regularizer, keras.regularizers.L1L2)
+
+        final_dropout = self.step_in(final_dense)
+        assert final_dropout.name == 'Main_tail_all_Dropout_2'
+        assert final_dropout.rate == 0.35
+
+        inner_act = self.step_in(final_dropout)
+        assert inner_act.get_config() == {'name': 'activation_1', 'trainable': True, 'activation': 'tanh'}
+
+        inner_dense = self.step_in(inner_act)
+        assert inner_dense.units == 64
+        assert isinstance(inner_dense.kernel_regularizer, keras.regularizers.L1L2)
+
+        inner_dropout = self.step_in(inner_dense)
+        assert inner_dropout.get_config() == {'name': 'Main_tail_all_Dropout_1', 'trainable': True, 'rate': 0.35,
+                                              'noise_shape': None, 'seed': None}
+
+        flatten = self.step_in(inner_dropout)
+        assert flatten.get_config() == {'name': 'Main_tail_all', 'trainable': True, 'data_format': 'channels_last'}
+
+        reduc_act = self.step_in(flatten)
+        assert reduc_act.get_config() == {'name': 'Main_tail_all_conv_act', 'trainable': True, 'alpha': 1.0}
+
+        reduc_conv = self.step_in(reduc_act)
+
+        assert reduc_conv.kernel_size == (1, 1)
+        assert reduc_conv.name == 'Main_tail_all_Conv_1x1'
+        assert reduc_conv.filters == 32
+        assert isinstance(reduc_conv.kernel_regularizer, keras.regularizers.L1L2)
+
+        input_layer = self.step_in(reduc_conv)
+        assert input_layer.input_shape == (None, 7, 1, 2)
+
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index 31c673f05d055eb7c4ee76318711de030d97d480..d3127de1afe0c1691b72dca0408e428fb5944bf4 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -28,11 +28,19 @@ def my_test_model(activation, window_history_size, channels, dropout_rate, add_m
     X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
     X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
     if add_minor_branch:
-        out = [flatten_tail(X_in, 'Minor_1', activation=activation)]
+        out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
+                            output_activation='linear', reduction_filter=64,
+                            name='Minor_1', dropout_rate=dropout_rate,
+                            )]
     else:
         out = []
     X_in = keras.layers.Dropout(dropout_rate)(X_in)
-    out.append(flatten_tail(X_in, 'Main', activation=activation))
+    out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
+                            output_activation='linear', reduction_filter=64,
+                            name='Main', dropout_rate=dropout_rate,
+                            ))
     return keras.Model(inputs=X_input, outputs=out)