Skip to content
Snippets Groups Projects
Commit 57511d09 authored by leufen1's avatar leufen1
Browse files

also be able to set pooling type

parent fbffafa4
No related branches found
No related tags found
5 merge requests!430update recent developments,!413update release branch,!412Resolve "release v2.0.0",!406Lukas issue368 feat prepare cnn class for filter benchmarking,!403Resolve "prepare CNN class for filter benchmarking"
Pipeline #93689 passed
......@@ -26,9 +26,65 @@ class CNN(AbstractModelClass): # pragma: no cover
_regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
_requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
_dropout = {"selu": keras.layers.AlphaDropout}
_pooling = {"max": keras.layers.MaxPooling2D, "average": keras.layers.AveragePooling2D,
"mean": keras.layers.AveragePooling2D}
"""
Define CNN model as in the following examples:
* use same kernel for all layers and use in total 3 conv layers, no dropout or pooling is applied
```python
model=CNN,
kernel_size=5,
n_layer=3,
dense_layer_configuration=[128, 64],
```
* specify the kernel sizes, make sure len of kernel size parameter matches number of layers
```python
model=CNN,
kernel_size=[3, 7, 11],
n_layer=3,
dense_layer_configuration=[128, 64],
```
* use different number of filters in each layer (can be combined either with fixed or individual kernel sizes),
make sure that lengths match. Using layer_configuration always overwrites any value given to n_layers parameter.
```python
model=CNN,
kernel_size=[3, 7, 11],
layer_configuration=[24, 48, 48],
```
* now specify individual kernel sizes and number of filters for each layer
```python
model=CNN,
layer_configuration=[(16, 3), (32, 7), (64, 11)],
dense_layer_configuration=[128, 64],
```
* add also some dropout and pooling every 2nd layer, dropout is applied after the conv layer, pooling before. Note
that pooling will not used in the init layer whereas dropout is already applied there.
```python
model=CNN,
dropout_freq=2,
dropout=0.3,
pooling_type="max",
pooling_freq=2,
pooling_size=3,
layer_configuration=[(16, 3), (32, 7), (64, 11)],
dense_layer_configuration=[128, 64],
```
"""
def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
optimizer="adam", regularizer=None, kernel_size=7, dropout=None, dropout_freq=None, pooling_freq=None,
pooling_type="max",
n_layer=1, n_filter=10, layer_configuration=None, pooling_size=None,
dense_layer_configuration=None, **kwargs):
......@@ -47,6 +103,7 @@ class CNN(AbstractModelClass): # pragma: no cover
self.optimizer = self._set_optimizer(optimizer, **kwargs)
self.layer_configuration = (n_layer, n_filter, self.kernel_size) if layer_configuration is None else layer_configuration
self.dense_layer_configuration = dense_layer_configuration or []
self.pooling = self._set_pooling(pooling_type)
self.pooling_size = pooling_size
self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
self.dropout_freq = self._set_layer_freq(dropout_freq)
......@@ -57,6 +114,12 @@ class CNN(AbstractModelClass): # pragma: no cover
self.set_compile_options()
self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss)
def _set_pooling(self, pooling):
try:
return self._pooling.get(pooling.lower())
except KeyError:
raise AttributeError(f"Given pooling {pooling} is not supported in this model class.")
def _set_layer_freq(self, param):
param = 0 if param is None else param
assert 0 <= param
......@@ -134,7 +197,7 @@ class CNN(AbstractModelClass): # pragma: no cover
x_in = x_input
for layer, (n_filter, kernel_size) in enumerate(conf):
if self.pooling_size is not None and self.pooling_freq > 0 and layer % self.pooling_freq == 0 and layer > 0:
x_in = keras.layers.MaxPooling2D((self.pooling_size, 1), strides=(1, 1), padding='valid')(x_in)
x_in = self.pooling((self.pooling_size, 1), strides=(1, 1), padding='valid')(x_in)
x_in = keras.layers.Conv2D(filters=n_filter, kernel_size=(kernel_size, 1),
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer)(x_in)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment