Skip to content
Snippets Groups Projects
Commit ef3ced19 authored by leufen1's avatar leufen1
Browse files

added dropout to CNN

parent e5a82238
No related branches found
No related tags found
5 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!300include cnn class,!271Resolve "create CNN model class",!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
...@@ -12,17 +12,22 @@ import keras ...@@ -12,17 +12,22 @@ import keras
class CNN(AbstractModelClass): class CNN(AbstractModelClass):
_activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"), _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
"sigmoid": partial(keras.layers.Activation, "sigmoid"), "sigmoid": partial(keras.layers.Activation, "sigmoid"),
"linear": partial(keras.layers.Activation, "linear"), "linear": partial(keras.layers.Activation, "linear"),
"selu": partial(keras.layers.Activation, "selu")} "selu": partial(keras.layers.Activation, "selu"),
_initializer = {"selu": keras.initializers.lecun_normal()} "prelu": partial(keras.layers.PReLU, alpha_initializer=keras.initializers.constant(value=0.25))}
_optimizer = {"adam": keras.optimizers.adam} _initializer = {"tanh": "glorot_uniform", "sigmoid": "glorot_uniform", "linear": "glorot_uniform",
"relu": keras.initializers.he_normal(), "selu": keras.initializers.lecun_normal(),
"prelu": keras.initializers.he_normal()}
_optimizer = {"adam": keras.optimizers.adam, "sgd": keras.optimizers.SGD}
_regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2} _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
_requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"] _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
_dropout = {"selu": keras.layers.AlphaDropout}
def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear", def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
optimizer="adam", regularizer=None, kernel_size=1, **kwargs): optimizer="adam", regularizer=None, kernel_size=1, dropout=None, **kwargs):
assert len(input_shape) == 1 assert len(input_shape) == 1
assert len(output_shape) == 1 assert len(output_shape) == 1
...@@ -37,6 +42,7 @@ class CNN(AbstractModelClass): ...@@ -37,6 +42,7 @@ class CNN(AbstractModelClass):
self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs) self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.optimizer = self._set_optimizer(optimizer, **kwargs) self.optimizer = self._set_optimizer(optimizer, **kwargs)
self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
# apply to model # apply to model
self.set_model() self.set_model()
...@@ -56,6 +62,8 @@ class CNN(AbstractModelClass): ...@@ -56,6 +62,8 @@ class CNN(AbstractModelClass):
opt_kwargs = {} opt_kwargs = {}
if opt_name == "adam": if opt_name == "adam":
opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"]) opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
elif opt_name == "sgd":
opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
return opt(**opt_kwargs) return opt(**opt_kwargs)
except KeyError: except KeyError:
raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.") raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
...@@ -77,6 +85,12 @@ class CNN(AbstractModelClass): ...@@ -77,6 +85,12 @@ class CNN(AbstractModelClass):
except KeyError: except KeyError:
raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.") raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
def _set_dropout(self, activation, dropout_rate):
    """Resolve the dropout layer class and rate to use in the model.

    :param activation: name of the model's activation function; "selu" is
        mapped (via ``self._dropout``) to ``AlphaDropout``, which preserves
        the self-normalizing property of selu networks, while any other
        activation falls back to the standard ``Dropout`` layer.
    :param dropout_rate: fraction of units to drop, expected in [0, 1),
        or ``None`` to disable dropout entirely.
    :return: tuple ``(dropout layer class, dropout rate)``; both elements
        are ``None`` when dropout is disabled.
    :raises ValueError: if ``dropout_rate`` is outside [0, 1).
    """
    if dropout_rate is None:
        return None, None
    # Use an explicit raise instead of assert: assert statements are
    # stripped when running under `python -O`, which would silently
    # accept invalid rates.
    if not 0 <= dropout_rate < 1:
        raise ValueError(f"dropout rate must be in [0, 1) but is {dropout_rate}.")
    return self._dropout.get(activation, keras.layers.Dropout), dropout_rate
def set_model(self): def set_model(self):
""" """
Build the model. Build the model.
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment