diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index fd126129a577f6b24200c55b088dd29cc58b0069..82bf3fe17f65be5d2422f8bde7ca42356cf0f67b 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -59,6 +59,7 @@ class BaseWrfChemDataLoader: physical_y_coord_name: str = DEFAULT_PHYSICAL_Y_COORD_NAME, physical_t_coord_name: str = DEFAULT_PHYSICAL_TIME_COORD_NAME, staged_vars: List[str] = DEFAULT_STAGED_VARS, + variables=None, z_coord_selector=None, staged_rotation_opts: Dict = DEFAULT_STAGED_ROTATION_opts, vars_to_rotate: Tuple[Tuple[Tuple[str, str], Tuple[str, str]]] = DEFAULT_VARS_TO_ROTATE, staged_dimension_mapping=None, stag_ending='_stag', @@ -80,6 +81,9 @@ class BaseWrfChemDataLoader: self.staged_rotation_opts = staged_rotation_opts self.vars_to_rotate = vars_to_rotate + self.variables = variables + self.z_coord_selector = z_coord_selector + if rechunk_values is None: self.rechunk_values = {self.time_dim_name: 1} else: @@ -118,12 +122,36 @@ class BaseWrfChemDataLoader: @TimeTrackingWrapper def open_data(self): - # see also https://github.com/pydata/xarray/issues/1385#issuecomment-438870575 - data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, - parallel=True, decode_cf=False) + if self.variables is None: + # see also https://github.com/pydata/xarray/issues/1385#issuecomment-438870575 + data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, + parallel=True, decode_cf=False) + else: + data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, + parallel=True, decode_cf=False, preprocess=self.preprocess_fkt_for_loader) data = xr.decode_cf(data) self._data = data + def preprocess_fkt_for_loader(self, ds): + # ToDo make genreal: Currently it's in fact hardcoded! + wind_var_mapping = {'Ull': ['U'], 'Vll': ['V'], 'U10ll': ['U10'], 'V10ll': ['V10'], + 'wspdll': ['U', 'V'], 'wdirll': ['U', 'V'], + 'wspd10ll': ['U10', 'V10'], 'wdir10ll': ['U10', 'V10'], + } + potential_wind_vars_list = list( + set(itertools.chain( + *itertools.chain(*SingleGridColumnWrfChemDataLoader.DEFAULT_VARS_TO_ROTATE)))) + none_wind_vars_to_keep = [x for x in self.variables if x not in potential_wind_vars_list] + # wind_vars = list(set(self.variables) - set(none_wind_vars_to_keep)) + # wind_vars_to_keep = [wind_var_mapping[i] for i in wind_vars] + # wind_vars_to_keep = list(set(itertools.chain(*wind_vars_to_keep))) + wind_vars_to_keep = ['U', 'V', 'U10', 'V10'] + combined_vars_to_keep = none_wind_vars_to_keep + wind_vars_to_keep + ds = ds[combined_vars_to_keep] + if self.z_coord_selector is not None: + ds = ds.sel({self.logical_z_coord_name: self.z_coord_selector}) + return ds + def assign_coords(self, coords, **coords_kwargs): """ Assign coords to WrfChemDataHandler._data @@ -151,12 +179,10 @@ class BaseWrfChemDataLoader: status = 'FAIL' print(f'{i}: {filenamestr} {status}') - @TimeTrackingWrapper def get_distances(self, lat, lon): dist = haversine_dist(lat1=self._data.XLAT, lon1=self._data.XLONG, lat2=lat, lon2=lon) return dist - @TimeTrackingWrapper def get_bearing(self, lat, lon, points_last=True): bearing = bearing_angle(lat1=lat, lon1=lon, lat2=self._data.XLAT, lon2=self._data.XLONG) if points_last: @@ -181,7 +207,6 @@ class BaseWrfChemDataLoader: self._data = data # logging.info('set dimensions as coordinates') - @TimeTrackingWrapper def apply_staged_transormation(self, mapping_of_stag2unstag=None): if mapping_of_stag2unstag is None: mapping_of_stag2unstag = {'U': 'U10', 'V': 'V10', 'U10': 'U10', 'V10': 'V10'} @@ -463,11 +488,17 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): raise ValueError(f"Pass an iterable with two items; (station, path)") lat, lon = self.coord_str2coords(station) with TimeTracking(name="Initialise loader (sgcWRFdh)"): + # preprocess_fkt_for_loader = self.preprocess_fkt_for_loader() + loader = SingleGridColumnWrfChemDataLoader((lat, lon), data_path=path, external_coords_file=self.external_coords_file, time_dim_name=self.time_dim, - rechunk_values=self.rechunk_values + rechunk_values=self.rechunk_values, + variables=self.variables, + z_coord_selector=self._joint_z_coord_selector + + # preprocess_open_mfdataset=preprocess_fkt_for_loader, ) self.__loader = loader @@ -640,7 +671,8 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): @TimeTrackingWrapper def extract_data_from_loader(self, loader): wind_dir_name = self._get_wind_dir_var_name(loader) - full_data = loader.data.isel(loader.get_nearest_icoords()).squeeze() + full_data = loader.data.isel(loader.get_nearest_icoords()).squeeze(dim=[loader.logical_y_coord_name, + loader.logical_x_coord_name]) data = full_data[self.variables] sec_data = self.windsector.get_sect_of_value(full_data[wind_dir_name]) self.sec_data = sec_data