Skip to content
Snippets Groups Projects
Commit 029f2144 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

add toy LSTM model

parent 31d83666
No related branches found
No related tags found
1 merge request!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
......@@ -477,3 +477,34 @@ class MyPaperModel(AbstractModelClass):
self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
self.compile_options = {"loss": [keras.losses.mean_squared_error, keras.losses.mean_squared_error],
"metrics": ['mse', 'mae']}
class MyLSTMModel(AbstractModelClass):
def __init__(self, input_shape: list, output_shape: list):
super().__init__(input_shape[0], output_shape[0])
# settings
self.dropout_rate = 0.2
# apply to model
self.set_model()
self.set_compile_options()
self.set_custom_objects(loss=self.compile_options['loss'])
def set_model(self):
x_input = keras.layers.Input(shape=self._input_shape)
x_in = keras.layers.LSTM(32, return_sequences=True, name="First_LSTM")(x_input)
x_in = keras.layers.LSTM(64, name="Second_LSTM")(x_in)
out_main = keras.layers.Dense(self._output_shape, name='Output_Dense')(x_in)
self.model = keras.Model(inputs=x_input, outputs=[out_main])
def set_compile_options(self):
self.initial_lr = 1e-4
self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr,
drop=.94,
epochs_drop=10)
self.loss = keras.losses.mean_squared_error
self.compile_options = {"metrics": ["mse", "mae"]}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment