Skip to content
Snippets Groups Projects
Commit 029f2144 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

add toy LSTM model

parent 31d83666
Branches
Tags
1 merge request!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
......@@ -477,3 +477,34 @@ class MyPaperModel(AbstractModelClass):
self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
self.compile_options = {"loss": [keras.losses.mean_squared_error, keras.losses.mean_squared_error],
"metrics": ['mse', 'mae']}
class MyLSTMModel(AbstractModelClass):
    """Toy two-layer LSTM model: a single input branch feeds two stacked
    LSTM layers and a dense projection onto the output size."""

    def __init__(self, input_shape: list, output_shape: list):
        """Store hyperparameters, then build the network and its compile setup.

        :param input_shape: list of input shapes; only the first entry is used
        :param output_shape: list of output shapes; only the first entry is used
        """
        super().__init__(input_shape[0], output_shape[0])

        # settings
        # NOTE(review): dropout_rate is defined but not referenced by
        # set_model — confirm whether a Dropout layer was intended.
        self.dropout_rate = 0.2

        # apply to model
        self.set_model()
        self.set_compile_options()
        self.set_custom_objects(loss=self.compile_options['loss'])

    def set_model(self):
        """Assemble the functional model: LSTM(32) -> LSTM(64) -> Dense."""
        x_input = keras.layers.Input(shape=self._input_shape)
        hidden = keras.layers.LSTM(32, return_sequences=True, name="First_LSTM")(x_input)
        hidden = keras.layers.LSTM(64, name="Second_LSTM")(hidden)
        prediction = keras.layers.Dense(self._output_shape, name='Output_Dense')(hidden)
        self.model = keras.Model(inputs=x_input, outputs=[prediction])

    def set_compile_options(self):
        """Configure SGD optimizer, step-wise LR decay, MSE loss and metrics."""
        self.initial_lr = 1e-4
        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
        self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(
            base_lr=self.initial_lr, drop=.94, epochs_drop=10)
        self.loss = keras.losses.mean_squared_error
        # NOTE(review): no "loss" key is set here, yet __init__ reads
        # compile_options['loss'] — presumably the AbstractModelClass
        # compile_options property merges self.loss in; confirm against
        # the base class before relying on this.
        self.compile_options = {"metrics": ["mse", "mae"]}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment