From a165f726539d4b977146807fb1a321d3bfb9869c Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Thu, 8 Apr 2021 13:41:54 +0200
Subject: [PATCH] manipulate open_mfdataset

---
 mlair/data_handler/data_handler_wrf_chem.py | 54 +++++++++++++--------
 1 file changed, 35 insertions(+), 19 deletions(-)

diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py
index 902d55ff..df1e958d 100644
--- a/mlair/data_handler/data_handler_wrf_chem.py
+++ b/mlair/data_handler/data_handler_wrf_chem.py
@@ -90,10 +90,11 @@ class BaseWrfChemDataLoader:
         self.start_time = start_time
         self.end_time = end_time
 
-        if rechunk_values is None:
-            self.rechunk_values = {self.time_dim_name: 1}
-        else:
-            self.rechunk_values = rechunk_values
+        # if rechunk_values is None:
+        #     self.rechunk_values = {self.time_dim_name: 1}
+        # else:
+        #     self.rechunk_values = rechunk_values
+        self.rechunk_values = rechunk_values
 
         self._stag_ending = stag_ending
         if staged_dimension_mapping is None:
@@ -140,6 +141,9 @@ class BaseWrfChemDataLoader:
 
     @TimeTrackingWrapper
     def open_data(self):
+        # see also https://github.com/pydata/xarray/issues/1385#issuecomment-438870575
+        # data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name,
+        #                          parallel=True, decode_cf=False)
         if self.variables is None:
             # see also https://github.com/pydata/xarray/issues/1385#issuecomment-438870575
             data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name,
@@ -148,6 +152,16 @@ class BaseWrfChemDataLoader:
             data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name,
                                      parallel=True, decode_cf=False, preprocess=self.preprocess_fkt_for_loader)
         data = xr.decode_cf(data)
+        # if self.variables is not None:
+        #     data = self.preprocess_fkt_for_loader(data)
+
+        # if self.rechunk_values is None:
+        #     chunks = {k: 'auto' for k in data.chunks.keys() }
+        #     chunks[self.time_dim_name] = -1
+        #     data = data.chunk(chunks)
+        #     # data = data.chunk("auto")
+        # else:
+        #     data = data.chunk(self.rechunk_values)
         self._data = data
 
     def preprocess_fkt_for_loader(self, ds):
@@ -160,9 +174,6 @@ class BaseWrfChemDataLoader:
             set(itertools.chain(
                 *itertools.chain(*SingleGridColumnWrfChemDataLoader.DEFAULT_VARS_TO_ROTATE))))
         none_wind_vars_to_keep = [x for x in self.variables if x not in potential_wind_vars_list]
-        # wind_vars = list(set(self.variables) - set(none_wind_vars_to_keep))
-        # wind_vars_to_keep = [wind_var_mapping[i] for i in wind_vars]
-        # wind_vars_to_keep = list(set(itertools.chain(*wind_vars_to_keep)))
         wind_vars_to_keep = ['U', 'V', 'U10', 'V10']
         combined_vars_to_keep = none_wind_vars_to_keep + wind_vars_to_keep
         ds = ds[combined_vars_to_keep]
@@ -178,7 +189,7 @@ class BaseWrfChemDataLoader:
         self._data = self._data.assign_coords(coords, **coords_kwargs)
 
     def rechunk_data(self, chunks=None, name_prefix='xarray-', token=None, lock=False):
-        self._data = self._data.chunk(chunks=chunks, name_prefix=name_prefix, token=token, lock=lock)
+        return self._data.chunk(chunks=chunks, name_prefix=name_prefix, token=token, lock=lock)
 
     def read_filenames(self):
         combs = []
@@ -324,7 +335,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
         if self.external_coords_file is not None:
             self._apply_external_coordinates()
         self.apply_staged_transormation()
-        #self.rechunk_data(self.rechunk_values)
+        # self.rechunk_data(self.rechunk_values)
         self._set_geoinfos()
         return self
 
@@ -390,7 +401,8 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
         else:
             raise TypeError(f"`coords' must be a tuple of floats or a dict, but is of type: {type(coords)}")
 
-    def get_coordinates(self, as_arrays=False) -> Union[Tuple[np.ndarray, np.ndarray], dict]:
+    def \
+            get_coordinates(self, as_arrays=False) -> Union[Tuple[np.ndarray, np.ndarray], dict]:
         if as_arrays:
             return np.array(self.__coords['lat']), np.array(self.__coords['lon'])
         else:
@@ -516,15 +528,17 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
             data = self.extract_data_from_loader(loader)
             if self._joint_z_coord_selector is not None:
                 data = data.sel({self._logical_z_coord_name: self._joint_z_coord_selector})
-            # expand dimesion for iterdim
-            data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim)
-            # transpose dataarray: set first three fixed and keep remaining as is
-            data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...)
-
-            with ProgressBar(), TimeTracking(name=f"{self.station}: compute data for slice_prep"):
-                data = dask.compute(self._slice_prep(data, start=start, end=end))[0]
-            # ToDo add metadata
-            meta = None
+            with ProgressBar(), TimeTracking(name=f"{self.station}: get data"):
+                logging.info(f"start compute data for {self.station} in load_data")
+                data = dask.compute(data)[0]
+
+        # expand dimension for iterdim
+        data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim)
+        # transpose dataarray: set first three fixed and keep remaining as is
+        data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...)
+        data = self._slice_prep(data, start=start, end=end)
+        # ToDo add metadata
+        meta = None
 
         return data, meta
 
@@ -655,6 +669,8 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn):
         if wind_dir_name not in data:
             data[wind_dir_name] = full_data[wind_dir_name]
             self._added_vars.append(to_list(wind_dir_name))
+        with ProgressBar():
+            data = dask.compute(data)[0]
         return data
 
     def _get_wind_dir_var_name(self, loader, wdir_name3d='wdirll', wdir_name2d='wdir10ll'):
-- 
GitLab