From fbffafa41fa12ca20247ce14ef0379b6ff6e05d3 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 1 Mar 2022 16:50:40 +0100
Subject: [PATCH] improved CNN, first try

---
 mlair/model_modules/convolutional_networks.py | 81 +++++++++++++++++--
 1 file changed, 75 insertions(+), 6 deletions(-)

diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py
index d8eb6eb3..5fd81133 100644
--- a/mlair/model_modules/convolutional_networks.py
+++ b/mlair/model_modules/convolutional_networks.py
@@ -17,7 +17,8 @@ class CNN(AbstractModelClass):  # pragma: no cover
                    "sigmoid": partial(keras.layers.Activation, "sigmoid"),
                    "linear": partial(keras.layers.Activation, "linear"),
                    "selu": partial(keras.layers.Activation, "selu"),
-                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25))}
+                   "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25)),
+                   "leakyrelu": partial(keras.layers.LeakyReLU)}
     _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
                     "relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
                     "prelu": keras.initializers.he_normal()}
@@ -27,7 +28,9 @@ class CNN(AbstractModelClass):  # pragma: no cover
     _dropout = {"selu": keras.layers.AlphaDropout}
 
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
-                 optimizer="adam", regularizer=None, kernel_size=1, dropout=None, **kwargs):
+                 optimizer="adam", regularizer=None, kernel_size=7, dropout=None, dropout_freq=None, pooling_freq=None,
+                 n_layer=1, n_filter=10, layer_configuration=None, pooling_size=None,
+                 dense_layer_configuration=None, **kwargs):
 
         assert len(input_shape) == 1
         assert len(output_shape) == 1
@@ -42,13 +45,24 @@ class CNN(AbstractModelClass):  # pragma: no cover
         self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
         self.kernel_size = kernel_size
         self.optimizer = self._set_optimizer(optimizer, **kwargs)
+        self.layer_configuration = (n_layer, n_filter, self.kernel_size) if layer_configuration is None else layer_configuration
+        self.dense_layer_configuration = dense_layer_configuration or []
+        self.pooling_size = pooling_size
         self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
+        self.dropout_freq = self._set_layer_freq(dropout_freq)
+        self.pooling_freq = self._set_layer_freq(pooling_freq)
 
         # apply to model
         self.set_model()
         self.set_compile_options()
         self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss)
 
+    def _set_layer_freq(self, param):
+        param = 0 if param is None else param
+        assert 0 <= param
+        assert isinstance(param, int)
+        return param
+
     def _set_activation(self, activation):
         try:
             return self._activation.get(activation.lower())
@@ -91,6 +105,65 @@ class CNN(AbstractModelClass):  # pragma: no cover
         assert 0 <= dropout_rate < 1
         return self._dropout.get(activation, keras.layers.Dropout), dropout_rate
 
+    def set_model(self):
+        """
+        Build the model.
+ """ + if isinstance(self.layer_configuration, tuple) is True: + n_layer, n_hidden, kernel_size = self.layer_configuration + if isinstance(kernel_size, list): + assert len(kernel_size) == n_layer # use individual filter sizes for each layer + conf = [(n_hidden, kernel_size[i]) for i in range(n_layer)] + else: + assert isinstance(kernel_size, int) # use same filter size for all layers + conf = [(n_hidden, kernel_size) for _ in range(n_layer)] + else: + assert isinstance(self.layer_configuration, list) is True + if not isinstance(self.layer_configuration[0], tuple): + if isinstance(self.kernel_size, list): + assert len(self.kernel_size) == len(self.layer_configuration) # use individual filter sizes for each layer + conf = [(n_filter, self.kernel_size[i]) for i, n_filter in enumerate(self.layer_configuration)] + else: + assert isinstance(self.kernel_size, int) # use same filter size for all layers + conf = [(n_filter, self.kernel_size) for n_filter in self.layer_configuration] + else: + assert len(self.layer_configuration[0]) == 2 + conf = self.layer_configuration + + x_input = keras.layers.Input(shape=self._input_shape) + x_in = x_input + for layer, (n_filter, kernel_size) in enumerate(conf): + if self.pooling_size is not None and self.pooling_freq > 0 and layer % self.pooling_freq == 0 and layer > 0: + x_in = keras.layers.MaxPooling2D((self.pooling_size, 1), strides=(1, 1), padding='valid')(x_in) + x_in = keras.layers.Conv2D(filters=n_filter, kernel_size=(kernel_size, 1), + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer)(x_in) + x_in = self.activation()(x_in) + if self.dropout is not None and self.dropout_freq > 0 and layer % self.dropout_freq == 0: + x_in = self.dropout(self.dropout_rate)(x_in) + + x_in = keras.layers.Flatten()(x_in) + for layer, n_hidden in enumerate(self.dense_layer_configuration): + if n_hidden < self._output_shape: + break + x_in = keras.layers.Dense(n_hidden, name=f"Dense_{len(conf) + layer + 1}", + kernel_initializer=self.kernel_initializer, )(x_in) + x_in = self.activation(name=f"{self.activation_name}_{len(conf) + layer + 1}")(x_in) + if self.dropout is not None: + x_in = self.dropout(self.dropout_rate)(x_in) + + x_in = keras.layers.Dense(self._output_shape)(x_in) + out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in) + self.model = keras.Model(inputs=x_input, outputs=[out]) + print(self.model.summary()) + + def set_compile_options(self): + self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])], + "metrics": ["mse", "mae", var_loss]} + + +class CNN_16_32_64(CNN): + def set_model(self): """ Build the model. @@ -123,7 +196,3 @@ class CNN(AbstractModelClass): # pragma: no cover x_in = keras.layers.Dense(self._output_shape)(x_in) out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in) self.model = keras.Model(inputs=x_input, outputs=[out]) - - def set_compile_options(self): - self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])], - "metrics": ["mse", "mae", var_loss]} -- GitLab