diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 50bedd15a438dc9814085f9b5c5c70fb31a71bad..5536f8d6dacb747128270474d392cef2b6d55697 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -462,6 +462,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): targetvar_logical_z_coord_selector=None, rechunk_values=None, date_format_of_nc_file=None, + as_image_like_data_format=True, **kwargs): self.external_coords_file = external_coords_file self.var_logical_z_coord_selector = self._ret_z_coord_select_if_valid(var_logical_z_coord_selector, @@ -474,6 +475,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.__loader = None self.rechunk_values = rechunk_values self.date_format_of_nc_file = date_format_of_nc_file + self.as_image_like_data_format = as_image_like_data_format super().__init__(*args, **kwargs) @staticmethod @@ -571,11 +573,14 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): return res def get_X(self, upsampling=False, as_numpy=False): + x_data = self.get_transposed_history() + if self.as_image_like_data_format is False: + x_data = x_data.squeeze() if as_numpy is True: # return None raise NotImplementedError(f"keyword argument `as_numpy=True' not implemented.") elif as_numpy is False: - return self.get_transposed_history() + return x_data # def get_Y(self, upsampling=False, as_numpy=False): # raise NotImplementedError @@ -608,6 +613,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.remove_nan(self.time_dim) self.history = self.modify_history() + self.label = self.modify_label() self.observation = self.modify_observation() self.remove_nan(self.time_dim) @@ -617,25 +623,31 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): # self.target_data = self.target_data.compute() # self._data.close() - def modify_history(self): + def modify_history(self, **kwargs): """ Place holder for more user spec. processing of history. + :param **kwargs: + :type **kwargs: :return: :rtype: """ return self.history - def modify_label(self): + def modify_label(self, **kwargs): """ Place holder for more user spec. processing of label. + :param **kwargs: + :type **kwargs: :return: :rtype: """ return self.label - def modify_observation(self): + def modify_observation(self, **kwargs): """ Place holder for more user spec. processing of observation. + :param **kwargs: + :type **kwargs: :return: :rtype: """ @@ -732,7 +744,7 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): return trafo @TimeTrackingWrapper - def modify_history(self): + def modify_history(self, **kwargs): if self.transformation_is_applied: ws_edges = self.get_applied_transdormation_on_wind_sector_edges() wind_dir_of_interest = self.compute_wind_dir_of_interest()