From f384b44849119cf321693f0e3c12e5ce37befda2 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 26 May 2021 13:59:35 +0200
Subject: [PATCH] plotting works better but is still too messy, improved rnn
 class

---
 mlair/helpers/filter.py                       | 233 ++++++++++++++----
 .../model_modules/fully_connected_networks.py |  55 +++--
 mlair/model_modules/recurrent_networks.py     |  95 ++++---
 3 files changed, 286 insertions(+), 97 deletions(-)

diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py
index a4288098..041e63c9 100644
--- a/mlair/helpers/filter.py
+++ b/mlair/helpers/filter.py
@@ -78,6 +78,7 @@ class ClimateFIRFilter:
         logging.info(f"{plot_name}: start init ClimateFIRFilter")
         self.plot_path = plot_path
         self.plot_name = plot_name
+        self.plot_data = []
         filtered = []
         h = []
         sel_opts = sel_opts if isinstance(sel_opts, dict) else {time_dim: sel_opts}
@@ -96,6 +97,9 @@ class ClimateFIRFilter:
         apriori_list = to_list(apriori)
         input_data = data.__deepcopy__()
 
+        # for viz
+        plot_dates = None
+
         # create tmp dimension to apply filter, search for unused name
         new_dim = self._create_tmp_dimension(input_data) if new_dim is None else new_dim
 
@@ -106,11 +110,13 @@ class ClimateFIRFilter:
             # ToDo: remove all methods except the vectorized version
             clim_filter: Callable = {True: self.clim_filter_vectorized_less_memory, False: self.clim_filter}[vectorized]
             _minimum_length = self._minimum_length(order, minimum_length, i)
-            fi, hi, apriori = clim_filter(input_data, fs, cutoff[i], order[i],
-                                          apriori=apriori_list[i],
-                                          sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, window=window,
-                                          var_dim=var_dim, plot_index=i, padlen_factor=padlen_factor,
-                                          minimum_length=_minimum_length, new_dim=new_dim)
+            fi, hi, apriori, plot_data = clim_filter(input_data, fs, cutoff[i], order[i],
+                                                     apriori=apriori_list[i],
+                                                     sel_opts=sel_opts, sampling=sampling, time_dim=time_dim,
+                                                     window=window,
+                                                     var_dim=var_dim, plot_index=i, padlen_factor=padlen_factor,
+                                                     minimum_length=_minimum_length, new_dim=new_dim,
+                                                     plot_dates=plot_dates)
 
             logging.info(f"{plot_name}: finished clim_filter calculation")
             if minimum_length is None:
@@ -119,6 +125,8 @@ class ClimateFIRFilter:
                 filtered.append(fi.sel({new_dim: slice(-minimum_length, 0)}))
             h.append(hi)
             gc.collect()
+            self.plot_data.append(plot_data)
+            plot_dates = {e["t0"] for e in plot_data}
 
             # calculate residuum
             logging.info(f"{plot_name}: calculate residuum")
@@ -158,6 +166,9 @@ class ClimateFIRFilter:
         self._h = h
         self._apriori = apriori_list
 
+        # visualize
+        self._plot(sampling)
+
     @staticmethod
     def _minimum_length(order, minimum_length, pos):
         next_order = 0
@@ -510,7 +521,7 @@ class ClimateFIRFilter:
     def clim_filter_vectorized_less_memory(self, data, fs, cutoff_high, order, apriori=None, padlen_factor=0.5,
                                            sel_opts=None,
                                            sampling="1d", time_dim="datetime", var_dim="variables", window="hamming",
-                                           plot_index=None, minimum_length=None, new_dim="window"):
+                                           plot_index=None, minimum_length=None, new_dim="window", plot_dates=None):
 
         logging.info(f"{data.coords['Stations'].values[0]}: extend apriori")
 
@@ -533,8 +544,10 @@ class ClimateFIRFilter:
 
         # collect some data for visualization
         plot_pos = np.array([0.25, 1.5, 2.75, 4]) * 365 * fs
-        plot_dates = [data.isel({time_dim: int(pos)}).coords[time_dim].values for pos in plot_pos if
-                      pos < len(data.coords[time_dim])]
+        if plot_dates is None:
+            plot_dates = [data.isel({time_dim: int(pos)}).coords[time_dim].values for pos in plot_pos if
+                          pos < len(data.coords[time_dim])]
+        plot_data = []
 
         coll = []
 
@@ -542,9 +555,6 @@ class ClimateFIRFilter:
             # self._tmp_analysis(data, apriori, var, var_dim, length, time_dim, new_dim, h)
             logging.info(f"{data.coords['Stations'].values[0]} ({var}): sel data")
 
-            # empty plot data collection
-            plot_data = []
-
             _start = pd.to_datetime(data.coords[time_dim].min().values).year
             _end = pd.to_datetime(data.coords[time_dim].max().values).year
             filt_coll = []
@@ -579,14 +589,6 @@ class ClimateFIRFilter:
                     continue
                 if len(filter_input_data.coords[time_dim]) == 0:  # no valid data for this year
                     continue
-                # filter_input_data = history.combine_first(future)
-                # history.sel(datetime=slice("2010-11-01", "2011-04-01"),variables="o3").plot()
-                # filter_input_data.sel(datetime=slice("2009-11-01", "2011-04-01"),variables="temp").plot()
-                # ToDo: remove all other filt methods, only keep the convolve one
-                # time_axis = filter_input_data.coords[time_dim]
-                # # apply vectorized fir filter along the tmp dimension
-                # kwargs = {"fs": fs, "cutoff_high": cutoff_high, "order": order,
-                #           "causal": False, "padlen": int(min(padlen_factor, 1) * length), "h": h}
 
                 logging.info(f"{data.coords['Stations'].values[0]} ({var}): start filter convolve")
                 with TimeTracking(name="convolve"):
@@ -604,37 +606,35 @@ class ClimateFIRFilter:
                     filt_coll.append(filt.sel({new_dim: slice(-minimum_length, 0)}, drop=True))
 
                 # visualization
-                # ToDo: move this code part into a separate plot method that is called on the fly, not afterwards
-                # just leave a call self.plot(*args) here!
-                # for viz_date in set(plot_dates).intersection(filt.coords[time_dim].values):
-                #     td_type = {"1d": "D", "1H": "h"}.get(sampling)
-                #     t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type)
-                #     t_plus = viz_date + np.timedelta64(int(extend_length_future), td_type)
-                #     if new_dim not in d.coords:
-                #         tmp_filter_data = self._shift_data(d.sel({time_dim: slice(t_minus, t_plus)}),
-                #                                            range(int(-extend_length_history),
-                #                                                  int(extend_length_future)),
-                #                                            time_dim, var_dim, new_dim)
-                #     else:
-                #         tmp_filter_data = d.sel({new_dim: slice(int(-extend_length_history), int(extend_length_future))})
-                #     tmp_filt_nc = xr.apply_ufunc(fir_filter_convolve_vectorized,
-                #                                  tmp_filter_data.sel({time_dim: viz_date}),
-                #                                  input_core_dims=[[new_dim]],
-                #                                  output_core_dims=[[new_dim]],
-                #                                  vectorize=True,
-                #                                  kwargs={"h": h},
-                #                                  output_dtypes=[d.dtype])
-                #
-                #     valid_range = range(int((length + 1) / 2) if minimum_length is None else minimum_length, 1)
-                #     plot_data.append({"t0": viz_date,
-                #                       "filt": filt.sel({time_dim: viz_date}),
-                #                       "filter_input": filter_input_data.sel({time_dim: viz_date}),
-                #                       "filt_nc": tmp_filt_nc,
-                #                       "valid_range": valid_range})
-
-            # select only values at tmp dimension 0 at each point in time
-            # coll.append(filt.sel({new_dim: 0}, drop=True))
-            # coll.append(filt.sel({new_dim: slice(-extend_length, 0)}, drop=True))
+                for viz_date in set(plot_dates).intersection(filt.coords[time_dim].values):
+                    try:
+                        td_type = {"1d": "D", "1H": "h"}.get(sampling)
+                        t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type)
+                        t_plus = viz_date + np.timedelta64(int(extend_length_future), td_type)
+                        if new_dim not in d.coords:
+                            tmp_filter_data = self._shift_data(d.sel({time_dim: slice(t_minus, t_plus)}),
+                                                               range(int(-extend_length_history),
+                                                                     int(extend_length_future)),
+                                                               time_dim, var_dim, new_dim).sel({time_dim: viz_date})
+                        else:
+                            # tmp_filter_data = d.sel({time_dim: viz_date,
+                            #                          new_dim: slice(int(-extend_length_history), int(extend_length_future))})
+                            tmp_filter_data = None
+                        valid_range = range(int((length + 1) / 2) if minimum_length is None else minimum_length, 1)
+                        plot_data.append({"t0": viz_date,
+                                          "var": var,
+                                          "filter_input": filter_input_data.sel({time_dim: viz_date}),
+                                          "filter_input_nc": tmp_filter_data,
+                                          "valid_range": valid_range,
+                                          "time_range": d.sel(
+                                              {time_dim: slice(t_minus, t_plus - np.timedelta64(1, td_type))}).coords[
+                                              time_dim].values,
+                                          "h": h,
+                                          "new_dim": new_dim})
+                    except:
+                        pass
+
+            # collect all filter results
             coll.append(xr.concat(filt_coll, time_dim))
             gc.collect()
 
@@ -671,7 +671,7 @@ class ClimateFIRFilter:
         # res_full.loc[res.coords] = res
         # res_full.compute()
         res_full = res.broadcast_like(xr.DataArray(dims=dims, coords=new_coords))
-        return res_full, h, apriori
+        return res_full, h, apriori, plot_data
 
     @staticmethod
     def _create_time_range_extend(year, sampling, extend_length):
@@ -708,6 +708,137 @@ class ClimateFIRFilter:
         res.name = index_name
         return res
 
+    def _plot(self, sampling):
+        new_dim = "window"
+        h = None
+        td_type = {"1d": "D", "1H": "h"}.get(sampling)
+        if self.plot_path is None:
+            return
+        plot_folder = os.path.join(os.path.abspath(self.plot_path), "climFIR")
+        if not os.path.exists(plot_folder):
+            os.makedirs(plot_folder)
+
+        rc_params = {'axes.labelsize': 'large',
+                     'xtick.labelsize': 'large',
+                     'ytick.labelsize': 'large',
+                     'legend.fontsize': 'large',
+                     'axes.titlesize': 'large',
+                     }
+        plt.rcParams.update(rc_params)
+
+        plot_dict = {}
+        for i, o in enumerate(range(len(self.plot_data))):
+            plot_data = self.plot_data[i]
+            for p_d in plot_data:
+                var = p_d.get("var")
+                t0 = p_d.get("t0")
+                filter_input = p_d.get("filter_input")
+                filter_input_nc = p_d.get("filter_input_nc")
+                valid_range = p_d.get("valid_range")
+                time_range = p_d.get("time_range")
+                new_dim = p_d.get("new_dim")
+                h = p_d.get("h")
+                plot_dict_var = plot_dict.get(var, {})
+                plot_dict_t0 = plot_dict_var.get(t0, {})
+                plot_dict_order = {"filter_input": filter_input,
+                                   "filter_input_nc": filter_input_nc,
+                                   "valid_range": valid_range,
+                                   "time_range": time_range,
+                                   "order": o, "h": h}
+                plot_dict_t0[i] = plot_dict_order
+                plot_dict_var[t0] = plot_dict_t0
+                plot_dict[var] = plot_dict_var
+
+        for var, viz_date_dict in plot_dict.items():
+            for it0, t0 in enumerate(viz_date_dict.keys()):
+                viz_data = viz_date_dict[t0]
+                residuum_true = None
+                for ifilter in sorted(viz_data.keys()):
+                    data = viz_data[ifilter]
+                    filter_input = data["filter_input"]
+                    filter_input_nc = data["filter_input_nc"] if residuum_true is None else residuum_true.sel(
+                        {new_dim: filter_input.coords[new_dim]})
+                    valid_range = data["valid_range"]
+                    time_axis = data["time_range"]
+                    # time_axis = pd.date_range(t_minus, t_plus, freq=sampling)
+                    filter_order = data["order"]
+                    h = data["h"]
+                    t_minus = t0 + np.timedelta64(-int(1.5 * valid_range.start), td_type)
+                    t_plus = t0 + np.timedelta64(int(0.5 * valid_range.start), td_type)
+                    fig, ax = plt.subplots()
+                    ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type),
+                               t0 + np.timedelta64(valid_range.stop - 1, td_type), color="whitesmoke",
+                               label="valid area")
+                    ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)")
+
+                    # original data
+                    ax.plot(time_axis, filter_input_nc.values.flatten(), color="darkgrey", linestyle="dashed",
+                            label="original")
+
+                    # clim apriori
+                    if ifilter == 0:
+                        d_tmp = filter_input.sel(
+                            {new_dim: slice(0, filter_input.coords[new_dim].values.max())}).values.flatten()
+                    else:
+                        d_tmp = filter_input.values.flatten()
+                    ax.plot(time_axis[len(time_axis) - len(d_tmp):], d_tmp, color="darkgrey", linestyle="solid",
+                            label="estimated future")
+
+                    # clim filter response
+                    filt = xr.apply_ufunc(fir_filter_convolve_vectorized, filter_input,
+                                          input_core_dims=[[new_dim]],
+                                          output_core_dims=[[new_dim]],
+                                          vectorize=True,
+                                          kwargs={"h": h},
+                                          output_dtypes=[filter_input.dtype])
+                    ax.plot(time_axis, filt.values.flatten(), color="black", linestyle="solid",
+                            label="clim filter response", linewidth=2)
+                    residuum_estimated = filter_input - filt
+
+                    # ideal filter response
+                    filt = xr.apply_ufunc(fir_filter_convolve_vectorized, filter_input_nc,
+                                          input_core_dims=[[new_dim]],
+                                          output_core_dims=[[new_dim]],
+                                          vectorize=True,
+                                          kwargs={"h": h},
+                                          output_dtypes=[filter_input.dtype])
+                    ax.plot(time_axis, filt.values.flatten(), color="black", linestyle="dashed",
+                            label="ideal filter response", linewidth=2)
+                    residuum_true = filter_input_nc - filt
+
+                    # set title, legend, and save plot
+                    ax_start = max(t_minus, time_axis[0])
+                    ax_end = min(t_plus, time_axis[-1])
+                    ax.set_xlim((ax_start, ax_end))
+                    plt.title(f"Input of ClimFilter ({str(var)})")
+                    plt.legend()
+                    fig.autofmt_xdate()
+                    plt.tight_layout()
+                    plot_name = os.path.join(plot_folder,
+                                             f"climFIR_{self.plot_name}_{str(var)}_{it0}_{ifilter}.pdf")
+                    plt.savefig(plot_name, dpi=300)
+                    plt.close('all')
+
+                    # plot residuum
+                    fig, ax = plt.subplots()
+                    ax.axvspan(t0 + np.timedelta64(-valid_range.start, td_type),
+                               t0 + np.timedelta64(valid_range.stop - 1, td_type), color="whitesmoke",
+                               label="valid area")
+                    ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)")
+                    ax.plot(time_axis, residuum_true.values.flatten(), color="black", linestyle="dashed",
+                            label="ideal filter residuum", linewidth=2)
+                    ax.plot(time_axis, residuum_estimated.values.flatten(), color="black", linestyle="solid",
+                            label="clim filter residuum", linewidth=2)
+                    ax.set_xlim((ax_start, ax_end))
+                    plt.title(f"Residuum of ClimFilter ({str(var)})")
+                    plt.legend()
+                    fig.autofmt_xdate()
+                    plt.tight_layout()
+                    plot_name = os.path.join(plot_folder,
+                                             f"climFIR_{self.plot_name}_{str(var)}_{it0}_{ifilter}_residuum.pdf")
+                    plt.savefig(plot_name, dpi=300)
+                    plt.close('all')
+
     def plot_new(self, viz_data, orig_data, var_dim, time_dim, new_dim, plot_index, sampling):
         try:
             td_type = {"1d": "D", "1H": "h"}.get(sampling)
diff --git a/mlair/model_modules/fully_connected_networks.py b/mlair/model_modules/fully_connected_networks.py
index 009ff060..ff06f075 100644
--- a/mlair/model_modules/fully_connected_networks.py
+++ b/mlair/model_modules/fully_connected_networks.py
@@ -20,7 +20,8 @@ class FCN(AbstractModelClass):
                    "sigmoid": partial(keras.layers.Activation, "sigmoid"),
                    "linear": partial(keras.layers.Activation, "linear"),
                    "selu": partial(keras.layers.Activation, "selu"),
-                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25))}
+                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
+                   "leakyrelu": partial(keras.layers.LeakyReLU)}
     _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
                     "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
                     "prelu": keras.initializers.he_normal()}
@@ -31,12 +32,31 @@ class FCN(AbstractModelClass):
 
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
                  optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
-                 **kwargs):
+                 batch_normalization=False, **kwargs):
         """
         Sets model and loss depending on the given arguments.
 
         :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
         :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+
+        Customize this FCN model via the following parameters:
+
+        :param activation: set your desired activation function. Chose from relu, tanh, sigmoid, linear, selu, prelu,
+            leakyrelu. (Default relu)
+        :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default
+            linear)
+        :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
+        :param n_layer: define number of hidden layers in the network. Given number of hidden neurons are used in each
+            layer. (Default 1)
+        :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer. (Default 10)
+        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
+            settings from n_layer and n_hidden. Provide a list where each element represent the number of units in the
+            hidden layer. The number of hidden layers is equal to the total length of this list.
+        :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
+            network at all. (Default None)
+        :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted
+            between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer
+            is added if set to false. (Default false)
         """
 
         assert len(input_shape) == 1
@@ -49,6 +69,7 @@ class FCN(AbstractModelClass):
         self.activation_output = self._set_activation(activation_output)
         self.activation_output_name = activation_output
         self.optimizer = self._set_optimizer(optimizer, **kwargs)
+        self.bn = batch_normalization
         self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
         self._update_model_name()
         self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
@@ -115,27 +136,29 @@ class FCN(AbstractModelClass):
         """
         Build the model.
         """
-        x_input = keras.layers.Input(shape=self._input_shape)
-        x_in = keras.layers.Flatten()(x_input)
         if isinstance(self.layer_configuration, tuple) is True:
             n_layer, n_hidden = self.layer_configuration
-            for layer in range(n_layer):
-                x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
-                                          kernel_regularizer=self.kernel_regularizer)(x_in)
-                x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
-                if self.dropout is not None:
-                    x_in = self.dropout(self.dropout_rate)(x_in)
+            conf = [n_hidden for _ in range(n_layer)]
         else:
             assert isinstance(self.layer_configuration, list) is True
-            for layer, n_hidden in enumerate(self.layer_configuration):
-                x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
-                                          kernel_regularizer=self.kernel_regularizer)(x_in)
-                x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
-                if self.dropout is not None:
-                    x_in = self.dropout(self.dropout_rate)(x_in)
+            conf = self.layer_configuration
+
+        x_input = keras.layers.Input(shape=self._input_shape)
+        x_in = keras.layers.Flatten()(x_input)
+
+        for layer, n_hidden in enumerate(conf):
+            x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
+                                      kernel_regularizer=self.kernel_regularizer)(x_in)
+            if self.bn is True:
+                x_in = keras.layers.BatchNormalization()(x_in)
+            x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
+            if self.dropout is not None:
+                x_in = self.dropout(self.dropout_rate)(x_in)
+
         x_in = keras.layers.Dense(self._output_shape)(x_in)
         out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out])
+        print(self.model.summary())
 
     def set_compile_options(self):
         self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])],
diff --git a/mlair/model_modules/recurrent_networks.py b/mlair/model_modules/recurrent_networks.py
index 953749c3..55a3d585 100644
--- a/mlair/model_modules/recurrent_networks.py
+++ b/mlair/model_modules/recurrent_networks.py
@@ -19,7 +19,8 @@ class RNN(AbstractModelClass):
                    "sigmoid": partial(keras.layers.Activation, "sigmoid"),
                    "linear": partial(keras.layers.Activation, "linear"),
                    "selu": partial(keras.layers.Activation, "selu"),
-                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25))}
+                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
+                   "leakyrelu": partial(keras.layers.LeakyReLU)}
     _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
                     "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
                     "prelu": keras.initializers.he_normal()}
@@ -27,15 +28,37 @@ class RNN(AbstractModelClass):
     _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
     _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
     _dropout = {"selu": keras.layers.AlphaDropout}
+    _rnn = {"lstm": keras.layers.LSTM, "gru": keras.layers.GRU}
 
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
                  optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
-                 **kwargs):
+                 batch_normalization=False, rnn_type="lstm", **kwargs):
         """
         Sets model and loss depending on the given arguments.
 
         :param input_shape: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
         :param output_shape: list of output shapes (expect len=1 with shape=(window_forecast))
+
+        Customize this RNN model via the following parameters:
+
+        :param activation: set your desired activation function. Chose from relu, tanh, sigmoid, linear, selu, prelu,
+            leakyrelu. (Default relu)
+        :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default
+            linear)
+        :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
+        :param n_layer: define number of hidden layers in the network. Given number of hidden neurons are used in each
+            layer. (Default 1)
+        :param n_hidden: define number of hidden units per layer. This number is used in each hidden layer. (Default 10)
+        :param layer_configuration: alternative formulation of the network's architecture. This will overwrite the
+            settings from n_layer and n_hidden. Provide a list where each element represent the number of units in the
+            hidden layer. The number of hidden layers is equal to the total length of this list.
+        :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
+            network at all. (Default None)
+        :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted
+            between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer
+            is added if set to false. (Default false)
+        :param rnn_type: define which kind of recurrent network should be applied. Chose from either lstm or gru. All
+            units will be of this kind. (Default lstm)
         """
 
         assert len(input_shape) == 1
@@ -43,13 +66,15 @@ class RNN(AbstractModelClass):
         super().__init__(input_shape[0], output_shape[0])
 
         # settings
-        # self.activation = self._set_activation(activation)
-        # self.activation_name = activation
-        self.activation_output = self._set_activation(activation_output)
+        self.activation = self._set_activation(activation.lower())
+        self.activation_name = activation
+        self.activation_output = self._set_activation(activation_output.lower())
         self.activation_output_name = activation_output
-        self.optimizer = self._set_optimizer(optimizer, **kwargs)
-        # self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
-        # self._update_model_name()
+        self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs)
+        self.bn = batch_normalization
+        self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
+        self.RNN = self._rnn.get(rnn_type.lower())
+        self._update_model_name(rnn_type)
         # self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
         # self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
         self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
@@ -63,29 +88,40 @@ class RNN(AbstractModelClass):
         """
         Build the model.
         """
+        if isinstance(self.layer_configuration, tuple) is True:
+            n_layer, n_hidden = self.layer_configuration
+            conf = [n_hidden for _ in range(n_layer)]
+        else:
+            assert isinstance(self.layer_configuration, list) is True
+            conf = self.layer_configuration
+
         x_input = keras.layers.Input(shape=self._input_shape)
         x_in = keras.layers.Reshape((self._input_shape[0], reduce((lambda x, y: x * y), self._input_shape[1:])))(
             x_input)
-        x_in = keras.layers.LSTM(32, return_sequences=True)(x_in)
-        if self.dropout is not None:
-            x_in = self.dropout(self.dropout_rate)(x_in)
-        x_in = keras.layers.LSTM(8)(x_in)
-        if self.dropout is not None:
-            x_in = self.dropout(self.dropout_rate)(x_in)
-        out = keras.layers.Dense(self._output_shape)(x_in)
+
+        for layer, n_hidden in enumerate(conf):
+            return_sequences = (layer < len(conf) - 1)
+            x_in = self.RNN(n_hidden, return_sequences=return_sequences)(x_in)
+            if self.bn is True:
+                x_in = keras.layers.BatchNormalization()(x_in)
+            x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
+            if self.dropout is not None:
+                x_in = self.dropout(self.dropout_rate)(x_in)
+
+        x_in = keras.layers.Dense(self._output_shape)(x_in)
+        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out])
         print(self.model.summary())
 
-        # x_input = keras.layers.Input(shape=self._input_shape)
-        # x_in = keras.layers.Reshape((self._input_shape[0], reduce((lambda x, y: x * y), self._input_shape[1:])))(
-        #     x_input)
         # x_in = keras.layers.LSTM(32)(x_in)
+        # if self.dropout is not None:
+        #     x_in = self.dropout(self.dropout_rate)(x_in)
         # x_in = keras.layers.RepeatVector(self._output_shape)(x_in)
         # x_in = keras.layers.LSTM(32, return_sequences=True)(x_in)
+        # if self.dropout is not None:
+        #     x_in = self.dropout(self.dropout_rate)(x_in)
         # out = keras.layers.TimeDistributed(keras.layers.Dense(1))(x_in)
         # out = keras.layers.Flatten()(out)
-        # self.model = keras.Model(inputs=x_input, outputs=[out])
-        # print(self.model.summary())
 
     def _set_dropout(self, activation, dropout_rate):
         if dropout_rate is None:
@@ -134,13 +170,12 @@ class RNN(AbstractModelClass):
     #         raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
     #
 
-    #
-    # def _update_model_name(self):
-    #     n_input = str(reduce(lambda x, y: x * y, self._input_shape))
-    #     n_output = str(self._output_shape)
-    #     if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
-    #         n_layer, n_hidden = self.layer_configuration
-    #         self.model_name += "_".join(["", n_input, *[f"{n_hidden}" for _ in range(n_layer)], n_output])
-    #     else:
-    #         self.model_name += "_".join(["", n_input, *[f"{n}" for n in self.layer_configuration], n_output])
-    #
+    def _update_model_name(self, rnn_type):
+        n_input = str(reduce(lambda x, y: x * y, self._input_shape))
+        n_output = str(self._output_shape)
+        self.model_name = rnn_type.upper()
+        if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
+            n_layer, n_hidden = self.layer_configuration
+            self.model_name += "_".join(["", n_input, *[f"{n_hidden}" for _ in range(n_layer)], n_output])
+        else:
+            self.model_name += "_".join(["", n_input, *[f"{n}" for n in self.layer_configuration], n_output])
-- 
GitLab