From a165f726539d4b977146807fb1a321d3bfb9869c Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Thu, 8 Apr 2021 13:41:54 +0200 Subject: [PATCH] manipulate open_mfdataset --- mlair/data_handler/data_handler_wrf_chem.py | 54 +++++++++++++-------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 902d55ff..df1e958d 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -90,10 +90,11 @@ class BaseWrfChemDataLoader: self.start_time = start_time self.end_time = end_time - if rechunk_values is None: - self.rechunk_values = {self.time_dim_name: 1} - else: - self.rechunk_values = rechunk_values + # if rechunk_values is None: + # self.rechunk_values = {self.time_dim_name: 1} + # else: + # self.rechunk_values = rechunk_values + self.rechunk_values = rechunk_values self._stag_ending = stag_ending if staged_dimension_mapping is None: @@ -140,6 +141,9 @@ class BaseWrfChemDataLoader: @TimeTrackingWrapper def open_data(self): + # see also https://github.com/pydata/xarray/issues/1385#issuecomment-438870575 + # data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, + # parallel=True, decode_cf=False) if self.variables is None: # see also https://github.com/pydata/xarray/issues/1385#issuecomment-438870575 data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, @@ -148,6 +152,16 @@ class BaseWrfChemDataLoader: data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, parallel=True, decode_cf=False, preprocess=self.preprocess_fkt_for_loader) data = xr.decode_cf(data) + # if self.variables is not None: + # data = self.preprocess_fkt_for_loader(data) + + # if self.rechunk_values is None: + # chunks = {k: 'auto' for k in data.chunks.keys() } + # chunks[self.time_dim_name] = -1 + # data = data.chunk(chunks) + # # data = data.chunk("auto") + # else: + # data = data.chunk(self.rechunk_values) self._data = data def preprocess_fkt_for_loader(self, ds): @@ -160,9 +174,6 @@ class BaseWrfChemDataLoader: set(itertools.chain( *itertools.chain(*SingleGridColumnWrfChemDataLoader.DEFAULT_VARS_TO_ROTATE)))) none_wind_vars_to_keep = [x for x in self.variables if x not in potential_wind_vars_list] - # wind_vars = list(set(self.variables) - set(none_wind_vars_to_keep)) - # wind_vars_to_keep = [wind_var_mapping[i] for i in wind_vars] - # wind_vars_to_keep = list(set(itertools.chain(*wind_vars_to_keep))) wind_vars_to_keep = ['U', 'V', 'U10', 'V10'] combined_vars_to_keep = none_wind_vars_to_keep + wind_vars_to_keep ds = ds[combined_vars_to_keep] @@ -178,7 +189,7 @@ class BaseWrfChemDataLoader: self._data = self._data.assign_coords(coords, **coords_kwargs) def rechunk_data(self, chunks=None, name_prefix='xarray-', token=None, lock=False): - self._data = self._data.chunk(chunks=chunks, name_prefix=name_prefix, token=token, lock=lock) + return self._data.chunk(chunks=chunks, name_prefix=name_prefix, token=token, lock=lock) def read_filenames(self): combs = [] @@ -324,7 +335,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): if self.external_coords_file is not None: self._apply_external_coordinates() self.apply_staged_transormation() - #self.rechunk_data(self.rechunk_values) + # self.rechunk_data(self.rechunk_values) self._set_geoinfos() return self @@ -390,7 +401,8 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): else: raise TypeError(f"`coords' must be a tuple of floats or a dict, but is of type: {type(coords)}") - def get_coordinates(self, as_arrays=False) -> Union[Tuple[np.ndarray, np.ndarray], dict]: + def \ + get_coordinates(self, as_arrays=False) -> Union[Tuple[np.ndarray, np.ndarray], dict]: if as_arrays: return np.array(self.__coords['lat']), np.array(self.__coords['lon']) else: @@ -516,15 +528,17 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): data = self.extract_data_from_loader(loader) if self._joint_z_coord_selector is not None: data = data.sel({self._logical_z_coord_name: self._joint_z_coord_selector}) - # expand dimesion for iterdim - data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim) - # transpose dataarray: set first three fixed and keep remaining as is - data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...) - - with ProgressBar(), TimeTracking(name=f"{self.station}: compute data for slice_prep"): - data = dask.compute(self._slice_prep(data, start=start, end=end))[0] - # ToDo add metadata - meta = None + with ProgressBar(), TimeTracking(name=f"{self.station}: get data"): + logging.info(f"start compute data for {self.station} in load_data") + data = dask.compute(data)[0] + + # expand dimesion for iterdim + data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim) + # transpose dataarray: set first three fixed and keep remaining as is + data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...) + data = self._slice_prep(data, start=start, end=end) + # ToDo add metadata + meta = None return data, meta @@ -655,6 +669,8 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): if wind_dir_name not in data: data[wind_dir_name] = full_data[wind_dir_name] self._added_vars.append(to_list(wind_dir_name)) + with ProgressBar(): + data = dask.compute(data)[0] return data def _get_wind_dir_var_name(self, loader, wdir_name3d='wdirll', wdir_name2d='wdir10ll'): -- GitLab