From 7c21ff6a91607d2f2fa558951a2ac415179978ce Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Sat, 10 Apr 2021 11:43:23 +0200 Subject: [PATCH] include boolean 'as_image_like_data_format' --- mlair/data_handler/data_handler_wrf_chem.py | 22 ++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 50bedd15..5536f8d6 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -462,6 +462,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): targetvar_logical_z_coord_selector=None, rechunk_values=None, date_format_of_nc_file=None, + as_image_like_data_format=True, **kwargs): self.external_coords_file = external_coords_file self.var_logical_z_coord_selector = self._ret_z_coord_select_if_valid(var_logical_z_coord_selector, @@ -474,6 +475,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.__loader = None self.rechunk_values = rechunk_values self.date_format_of_nc_file = date_format_of_nc_file + self.as_image_like_data_format = as_image_like_data_format super().__init__(*args, **kwargs) @staticmethod @@ -571,11 +573,14 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): return res def get_X(self, upsampling=False, as_numpy=False): + x_data = self.get_transposed_history() + if self.as_image_like_data_format is False: + x_data = x_data.squeeze() if as_numpy is True: # return None raise NotImplementedError(f"keyword argument `as_numpy=True' not implemented.") elif as_numpy is False: - return self.get_transposed_history() + return x_data # def get_Y(self, upsampling=False, as_numpy=False): # raise NotImplementedError @@ -608,6 +613,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.remove_nan(self.time_dim) self.history = self.modify_history() + self.label = self.modify_label() self.observation = self.modify_observation() self.remove_nan(self.time_dim) @@ -617,25 +623,31 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): # self.target_data = self.target_data.compute() # self._data.close() - def modify_history(self): + def modify_history(self, **kwargs): """ Place holder for more user spec. processing of history. + :param **kwargs: + :type **kwargs: :return: :rtype: """ return self.history - def modify_label(self): + def modify_label(self, **kwargs): """ Place holder for more user spec. processing of label. + :param **kwargs: + :type **kwargs: :return: :rtype: """ return self.label - def modify_observation(self): + def modify_observation(self, **kwargs): """ Place holder for more user spec. processing of observation. + :param **kwargs: + :type **kwargs: :return: :rtype: """ @@ -732,7 +744,7 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): return trafo @TimeTrackingWrapper - def modify_history(self): + def modify_history(self, **kwargs): if self.transformation_is_applied: ws_edges = self.get_applied_transdormation_on_wind_sector_edges() wind_dir_of_interest = self.compute_wind_dir_of_interest() -- GitLab