diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 57c8d67825947ee27fbebc4393e369f2e6a092e3..c1c2d0c5499ea617c7d3c01ff694e75378f855ff 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -9,13 +9,13 @@ import matplotlib.pyplot as plt import dask import inspect import os -from mlair.helpers.geofunctions import haversine_dist, bearing_angle, WindSector -from mlair.helpers.helpers import convert2xrda, remove_items +from mlair.helpers.geofunctions import haversine_dist, bearing_angle, WindSector, VectorRotateLambertConformal2latlon +from mlair.helpers.helpers import convert2xrda, remove_items, to_list from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation from mlair.data_handler import DefaultDataHandler -from typing import Tuple, Union, Dict +from typing import Tuple, Union, Dict, List import logging import cartopy.crs as ccrs @@ -40,6 +40,14 @@ class BaseWrfChemDataLoader: DEFAULT_RECHUNK = {"Time": -1, "y": 36, "x": 40} DEFAULT_FILE_STARTER = 'wrfout_d0' + DEFAULT_STAGED_VARS = ('U', 'V') + DEFAULT_STAGED_ROTATION_opts = dict(cen_lon=12., cen_lat=52.5, + truelat1=30., truelat2=60., + stand_lon=12.) + DEFAULT_VARS_TO_ROTATE = ((('U', 'V'), ('Ull', 'Vll')), (('U10', 'V10'), ('U10ll', 'V10ll')), + (('U', 'V'), ('wspdll', 'wdirll')), (('U10', 'V10'), ('wspd10ll', 'wdir10ll')) + ) + def __init__(self, data_path: str, common_file_starter: str = DEFAULT_FILE_STARTER, time_dim_name: str = DEFAULT_LOGICAL_TIME_COORD_NAME, rechunk_values: Dict = None, @@ -48,7 +56,12 @@ class BaseWrfChemDataLoader: logical_z_coord_name: str = DEFAULT_LOGICAL_Z_COORD_NAME, physical_x_coord_name: str = DEFAULT_PHYSICAL_X_COORD_NAME, physical_y_coord_name: str = DEFAULT_PHYSICAL_Y_COORD_NAME, - physical_t_coord_name: str = DEFAULT_PHYSICAL_TIME_COORD_NAME + physical_t_coord_name: str = DEFAULT_PHYSICAL_TIME_COORD_NAME, + staged_vars: List[str] = DEFAULT_STAGED_VARS, + staged_rotation_opts: Dict = DEFAULT_STAGED_ROTATION_opts, + vars_to_rotate: Tuple[Tuple[Tuple[str, str], Tuple[str, str]]] = DEFAULT_VARS_TO_ROTATE, + staged_dimension_mapping = None, stag_ending = '_stag', + ): super().__init__() self.data_path = data_path @@ -63,16 +76,38 @@ class BaseWrfChemDataLoader: self.physical_y_coord_name = physical_y_coord_name self.physical_t_coord_name = physical_t_coord_name + self.staged_vars = to_list(staged_vars) + self.staged_rotation_opts = staged_rotation_opts + self.vars_to_rotate = vars_to_rotate + if rechunk_values is None: self.rechunk_values = {self.time_dim_name: 1} else: self.rechunk_values = rechunk_values + self._stag_ending = stag_ending + if staged_dimension_mapping is None: + self.staged_dimension_mapping = {'X': self.logical_x_coord_name+self._stag_ending, + 'Y': self.logical_y_coord_name+self._stag_ending, + 'Z': self.logical_z_coord_name+self._stag_ending,} + else: + self.staged_dimension_mapping = staged_dimension_mapping + # internal self._X = None self._Y = None self._data = None + @property + def staged_dimension_mapping(self): + return self.__staged_dimension_mapping + + @staged_dimension_mapping.setter + def staged_dimension_mapping(self, var: dict): + if not isinstance(var, dict): + raise TypeError(f"`var' must be a dict, but is of type {type(var)}") + self.__staged_dimension_mapping = var + @property def data(self) -> xr.Dataset: return self._data @@ -149,6 +184,65 @@ class BaseWrfChemDataLoader: self._data = data logging.info('set dimensions as coordinates') + def apply_staged_transormation(self, mapping_of_stag2unstag=None): + if mapping_of_stag2unstag is None: + mapping_of_stag2unstag = {'U': 'U10', 'V': 'V10', 'U10': 'U10', 'V10': 'V10'} + vectorrot = VectorRotateLambertConformal2latlon(xlat=self._data[self.physical_y_coord_name].data, + xlong=self._data[self.physical_x_coord_name].data, + **self.staged_rotation_opts) + if self._data is None: + raise IOError(f"`open_data' must be called before vector interpolation and rotation") + for (u_stag_name, v_stag_name), (u_ll_name, v_ll_name) in self.vars_to_rotate: + u_staged_field = self._data[u_stag_name] + v_staged_field = self._data[v_stag_name] + + u_target_field = self._data[mapping_of_stag2unstag[u_stag_name]] + v_target_field = self._data[mapping_of_stag2unstag[v_stag_name]] + + u_stagger = u_staged_field.attrs['stagger'] + v_stagger = v_staged_field.attrs['stagger'] + + if u_stagger: + u_grd = self.set_interpolated_field(staged_field=u_staged_field, target_field=u_target_field) + else: + u_grd = self._data[u_stag_name] + + if v_stagger: + v_grd = self.set_interpolated_field(staged_field=v_staged_field, target_field=v_target_field) + else: + v_grd = self._data[v_stag_name] + + if u_ll_name[0] == 'U' and v_ll_name[0] == 'V': + ull, vll = vectorrot.ugrd_vgrd2ull_vll(u_grd, v_grd) + elif u_ll_name[:4] == 'wspd' and v_ll_name[:4] == 'wdir': + ull, vll = vectorrot.ugrd_vgrd2wspd_wdir(u_grd, v_grd) + else: + raise ValueError( + f"`u_ll_name' and `u_ll_name' must either start with 'U' and 'V' or 'wspd' and 'wdir', " + f"but they are u_ll_name={u_ll_name} and v_ll_name={v_ll_name}") + + self._data[u_ll_name] = ull + self._data[v_ll_name] = vll + + def set_interpolated_field(self, staged_field: xr.DataArray, target_field: xr.DataArray, + dropped_staged_attrs: List[str] =None, **kwargs): + stagger_attr_name = kwargs.pop('stagger_attr_name', 'stagger') + stagger = kwargs.pop('stagger', staged_field.attrs[stagger_attr_name]) + if dropped_staged_attrs is None: + dropped_staged_attrs = [stagger_attr_name] + + excluded_dims = [self.staged_dimension_mapping[stagger]] + target_dim_order = [dim.replace('_stag', '') for dim in staged_field.dims] + new_field = xr.broadcast(target_field, staged_field, exclude=excluded_dims)[0] + new_field = new_field.transpose(*target_dim_order) + new_field.attrs.update(remove_items(staged_field.attrs, dropped_staged_attrs)) + + new_field.data = VectorRotateLambertConformal2latlon.interpolate_to_grid_center(staged_field, + self.staged_dimension_mapping[stagger]) + return new_field + + + class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): DEFAULT_MODEL = "WRF-Chem" @@ -191,11 +285,12 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self._set_dims_as_coords() if external_coords_file is not None: self._apply_external_coordinates() + self.apply_staged_transormation() self.rechunk_data(self.rechunk_values) # self._set_nearest_icoords(dim=[self.logical_x_coord_name, self.logical_y_coord_name]) # self._set_nearest_coords() self._set_geoinfos() - print('dummy test') + logging.debug("SingleGridColumnWrfChemDataLoader Initialised") def _set_geoinfos(self): # identify nearest coords @@ -237,6 +332,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): for k, v in ds_coords.coords.variables.items(): data = data.assign_coords({k: (remove_items(list(v.dims), 'Time'), v.values.squeeze())}) self._data = data + ds_coords.close() logging.info('setup external coords') def _set_coords(self, coords): @@ -294,7 +390,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): class DataHandlerSingleGridColumn(DataHandlerSingleStation): _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) - def __init__(self, *args, external_coords_file, **kwargs): + def __init__(self, *args, external_coords_file=None, **kwargs): self.external_coords_file = external_coords_file super().__init__(*args, **kwargs) @@ -324,10 +420,12 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): # select defined variables at grid box or grid coloumn based on nearest icoords data = sgc_loader.data.isel(sgc_loader.get_nearest_icoords()).squeeze()[self.variables] # expand dimesion for iterdim + # data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim).compute() data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim) # transpose dataarray: set first three fixed and keep remaining as is data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...) - data = self._slice_prep(data, start=start, end=end) + data = dask.compute(self._slice_prep(data, start=start, end=end))[0] + sgc_loader.data.close() # data = self.check_for_negative_concentrations(data) # ToDo @@ -337,7 +435,8 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): # ToDo # data, meta = None, None # raise NotImplementedError - return data.chunk({self.time_dim:-1}), meta + # return data.chunk({self.time_dim:-1}), meta + return data, meta def get_X(self, upsampling=False, as_numpy=False): if as_numpy is True: @@ -356,12 +455,16 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): # # self.target_data = targets # raise NotImplementedError - # def make_samples(self): - # # self.make_history_window(self.target_dim, self.window_history_size, self.time_dim) - # # self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time) - # # self.make_observation(self.target_dim, self.target_var, self.time_dim) - # # self.remove_nan(self.time_dim) - # raise NotImplementedError + def make_samples(self): + self.make_history_window(self.target_dim, self.window_history_size, self.time_dim) + self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time) + self.make_observation(self.target_dim, self.target_var, self.time_dim) + self.remove_nan(self.time_dim) + # self.input_data = self.input_data.compute() + # self.label = self.label.compute() + # self.observation = self.observation.compute() + # self.target_data = self.target_data.compute() + self._data.close() # @TimeTrackingWrapper # def setup_samples(self): @@ -530,7 +633,19 @@ if __name__ == '__main__': windsect = WindSector() radius_from_point = 150. # in km - # dummy_plot(geo_info.bearing.where(dist_xr <= radius_from_point), True, [2, 15], [45, 58]) + ###test for rotation + vecrot = VectorRotateLambertConformal2latlon(xlat=wrf_new.data.XLAT.data, xlong=wrf_new.data.XLONG.data) + ugrd = vecrot.interpolate_to_grid_center(wrf_new.data.U, 'west_east_stag') + vgrd = vecrot.interpolate_to_grid_center(wrf_new.data.V, 'south_north_stag') + ull_grd, vll_grd = vecrot.ugrd_vgrd2ull_vll(ugrd, vgrd) + ull_stg, vll_stg = vecrot.ustg_vstg2ull_vll(wrf_new.data.U, wrf_new.data.V, + ustg_dim='west_east_stag', vstg_dim='south_north_stag') + + wspd_grd, wdir_grd = vecrot.ugrd_vgrd2wspd_wdir(ugrd, vgrd) + wspd_stg, wdir_stg = vecrot.ustg_vstg2wspd_wdir(wrf_new.data.U, wrf_new.data.V, + ustg_dim='west_east_stag', vstg_dim='south_north_stag') + ### end test rotation + for i, (data, xlim, ylim, kwargs) in enumerate(((wrf_new._data.T2.isel({'XTIME': 0}), [-42, 66], [23, 80], {'circle': True, 'cbar_kwargs': {'orientation': 'horizontal', 'pad': 0.01}}),