diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 50bedd15a438dc9814085f9b5c5c70fb31a71bad..5536f8d6dacb747128270474d392cef2b6d55697 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -462,6 +462,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): targetvar_logical_z_coord_selector=None, rechunk_values=None, date_format_of_nc_file=None, + as_image_like_data_format=True, **kwargs): self.external_coords_file = external_coords_file self.var_logical_z_coord_selector = self._ret_z_coord_select_if_valid(var_logical_z_coord_selector, @@ -474,6 +475,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.__loader = None self.rechunk_values = rechunk_values self.date_format_of_nc_file = date_format_of_nc_file + self.as_image_like_data_format = as_image_like_data_format super().__init__(*args, **kwargs) @staticmethod @@ -571,11 +573,14 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): return res def get_X(self, upsampling=False, as_numpy=False): + x_data = self.get_transposed_history() + if self.as_image_like_data_format is False: + x_data = x_data.squeeze() if as_numpy is True: # return None raise NotImplementedError(f"keyword argument `as_numpy=True' not implemented.") elif as_numpy is False: - return self.get_transposed_history() + return x_data # def get_Y(self, upsampling=False, as_numpy=False): # raise NotImplementedError @@ -608,6 +613,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.remove_nan(self.time_dim) self.history = self.modify_history() + self.label = self.modify_label() self.observation = self.modify_observation() self.remove_nan(self.time_dim) @@ -617,25 +623,31 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): # self.target_data = self.target_data.compute() # self._data.close() - def modify_history(self): + def modify_history(self, **kwargs): """ Place holder for more user spec. processing of history. + :param **kwargs: + :type **kwargs: :return: :rtype: """ return self.history - def modify_label(self): + def modify_label(self, **kwargs): """ Place holder for more user spec. processing of label. + :param **kwargs: + :type **kwargs: :return: :rtype: """ return self.label - def modify_observation(self): + def modify_observation(self, **kwargs): """ Place holder for more user spec. processing of observation. + :param **kwargs: + :type **kwargs: :return: :rtype: """ @@ -732,7 +744,7 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): return trafo @TimeTrackingWrapper - def modify_history(self): + def modify_history(self, **kwargs): if self.transformation_is_applied: ws_edges = self.get_applied_transdormation_on_wind_sector_edges() wind_dir_of_interest = self.compute_wind_dir_of_interest() diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index f8e3a21a81351ac614e2275749bb85fa82a96e02..930b47a98114482059b355a7968f6044df12413a 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -477,3 +477,34 @@ class MyPaperModel(AbstractModelClass): self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9) self.compile_options = {"loss": [keras.losses.mean_squared_error, keras.losses.mean_squared_error], "metrics": ['mse', 'mae']} + + +class MyLSTMModel(AbstractModelClass): + + def __init__(self, input_shape: list, output_shape: list): + + super().__init__(input_shape[0], output_shape[0]) + + # settings + self.dropout_rate = 0.2 + + # apply to model + self.set_model() + self.set_compile_options() + self.set_custom_objects(loss=self.compile_options['loss']) + + def set_model(self): + x_input = keras.layers.Input(shape=self._input_shape) + x_in = keras.layers.LSTM(32, return_sequences=True, name="First_LSTM")(x_input) + x_in = keras.layers.LSTM(64, name="Second_LSTM")(x_in) + out_main = keras.layers.Dense(self._output_shape, name='Output_Dense')(x_in) + self.model = keras.Model(inputs=x_input, outputs=[out_main]) + + def set_compile_options(self): + self.initial_lr = 1e-4 + self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9) + self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, + drop=.94, + epochs_drop=10) + self.loss = keras.losses.mean_squared_error + self.compile_options = {"metrics": ["mse", "mae"]}