diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py
index 50bedd15a438dc9814085f9b5c5c70fb31a71bad..5536f8d6dacb747128270474d392cef2b6d55697 100644
--- a/mlair/data_handler/data_handler_wrf_chem.py
+++ b/mlair/data_handler/data_handler_wrf_chem.py
@@ -462,6 +462,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
                  targetvar_logical_z_coord_selector=None,
                  rechunk_values=None,
                  date_format_of_nc_file=None,
+                 as_image_like_data_format=True,
                  **kwargs):
         self.external_coords_file = external_coords_file
         self.var_logical_z_coord_selector = self._ret_z_coord_select_if_valid(var_logical_z_coord_selector,
@@ -474,6 +475,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         self.__loader = None
         self.rechunk_values = rechunk_values
         self.date_format_of_nc_file = date_format_of_nc_file
+        self.as_image_like_data_format = as_image_like_data_format
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -571,11 +573,14 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         return res
 
     def get_X(self, upsampling=False, as_numpy=False):
+        x_data = self.get_transposed_history()
+        if self.as_image_like_data_format is False:
+            x_data = x_data.squeeze()
         if as_numpy is True:
             # return None
             raise NotImplementedError(f"keyword argument `as_numpy=True' not implemented.")
         elif as_numpy is False:
-            return self.get_transposed_history()
+            return x_data
 
     # def get_Y(self, upsampling=False, as_numpy=False):
     #     raise NotImplementedError
@@ -608,6 +613,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
 
         self.remove_nan(self.time_dim)
         self.history = self.modify_history()
+
         self.label = self.modify_label()
         self.observation = self.modify_observation()
         self.remove_nan(self.time_dim)
@@ -617,25 +623,31 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         # self.target_data = self.target_data.compute()
         # self._data.close()
 
-    def modify_history(self):
+    def modify_history(self, **kwargs):
         """
         Place holder for more user spec. processing of history.
+        :param **kwargs:
+        :type **kwargs:
         :return:
         :rtype:
         """
         return self.history
 
-    def modify_label(self):
+    def modify_label(self, **kwargs):
         """
         Place holder for more user spec. processing of label.
+        :param **kwargs:
+        :type **kwargs:
         :return:
         :rtype:
         """
         return self.label
 
-    def modify_observation(self):
+    def modify_observation(self, **kwargs):
         """
         Place holder for more user spec. processing of observation.
+        :param **kwargs:
+        :type **kwargs:
         :return:
         :rtype:
         """
@@ -732,7 +744,7 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn):
         return trafo
 
     @TimeTrackingWrapper
-    def modify_history(self):
+    def modify_history(self, **kwargs):
         if self.transformation_is_applied:
             ws_edges = self.get_applied_transdormation_on_wind_sector_edges()
             wind_dir_of_interest = self.compute_wind_dir_of_interest()
diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index f8e3a21a81351ac614e2275749bb85fa82a96e02..930b47a98114482059b355a7968f6044df12413a 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -477,3 +477,34 @@ class MyPaperModel(AbstractModelClass):
         self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
         self.compile_options = {"loss": [keras.losses.mean_squared_error, keras.losses.mean_squared_error],
                                 "metrics": ['mse', 'mae']}
+
+
+class MyLSTMModel(AbstractModelClass):
+
+    def __init__(self, input_shape: list, output_shape: list):
+
+        super().__init__(input_shape[0], output_shape[0])
+
+        # settings
+        self.dropout_rate = 0.2
+
+        # apply to model
+        self.set_model()
+        self.set_compile_options()
+        self.set_custom_objects(loss=self.compile_options['loss'])
+
+    def set_model(self):
+        x_input = keras.layers.Input(shape=self._input_shape)
+        x_in = keras.layers.LSTM(32, return_sequences=True, name="First_LSTM")(x_input)
+        x_in = keras.layers.LSTM(64, name="Second_LSTM")(x_in)
+        out_main = keras.layers.Dense(self._output_shape, name='Output_Dense')(x_in)
+        self.model = keras.Model(inputs=x_input, outputs=[out_main])
+
+    def set_compile_options(self):
+        self.initial_lr = 1e-4
+        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
+        self.lr_decay = mlair.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr,
+                                                                               drop=.94,
+                                                                               epochs_drop=10)
+        self.loss = keras.losses.mean_squared_error
+        self.compile_options = {"metrics": ["mse", "mae"]}