Skip to content
Snippets Groups Projects
Commit 7673e830 authored by leufen1's avatar leufen1
Browse files

kernel size can be set from outside

parent 7e529068
Branches
No related tags found
5 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!300include cnn class,!271Resolve "create CNN model class",!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #62655 passed
...@@ -22,7 +22,7 @@ class CNN(AbstractModelClass): ...@@ -22,7 +22,7 @@ class CNN(AbstractModelClass):
_requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"] _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"]
def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear", def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
optimizer="adam", regularizer=None, **kwargs): optimizer="adam", regularizer=None, kernel_size=1, **kwargs):
assert len(input_shape) == 1 assert len(input_shape) == 1
assert len(output_shape) == 1 assert len(output_shape) == 1
...@@ -35,6 +35,7 @@ class CNN(AbstractModelClass): ...@@ -35,6 +35,7 @@ class CNN(AbstractModelClass):
self.activation_output_name = activation_output self.activation_output_name = activation_output
self.kernel_initializer = self._initializer.get(activation, "glorot_uniform") self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs) self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
self.kernel_size = kernel_size
self.optimizer = self._set_optimizer(optimizer, **kwargs) self.optimizer = self._set_optimizer(optimizer, **kwargs)
# apply to model # apply to model
...@@ -81,7 +82,7 @@ class CNN(AbstractModelClass): ...@@ -81,7 +82,7 @@ class CNN(AbstractModelClass):
Build the model. Build the model.
""" """
x_input = keras.layers.Input(shape=self._input_shape) x_input = keras.layers.Input(shape=self._input_shape)
kernel = (5, 1) kernel = (self.kernel_size, 1)
pad_size = PadUtils.get_padding_for_same(kernel) pad_size = PadUtils.get_padding_for_same(kernel)
x_in = Padding2D("SymPad2D")(padding=pad_size, name="SymPad1")(x_input) x_in = Padding2D("SymPad2D")(padding=pad_size, name="SymPad1")(x_input)
x_in = keras.layers.Conv2D(filters=16, kernel_size=kernel, x_in = keras.layers.Conv2D(filters=16, kernel_size=kernel,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment