diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 409a31e476448413f562df319a4c840df39d9fb0..c7313ba62f314e35bba14790b1e4d32c7ad10964 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -12,7 +12,7 @@ import os import gc from mlair.helpers.geofunctions import haversine_dist, bearing_angle, WindSector, VectorRotateLambertConformal2latlon from mlair.helpers.helpers import convert2xrda, remove_items, to_list -from mlair import helpers +from mlair.helpers import TimeTrackingWrapper from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation @@ -121,12 +121,14 @@ class BaseWrfChemDataLoader: def dataset_search_str(self): return os.path.join(self.data_path, self.common_file_starter + '*') + @TimeTrackingWrapper def open_data(self): logging.debug(f'open data: {self.dataset_search_str}') data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, parallel=True) self._data = data + @TimeTrackingWrapper def assign_coords(self, coords, **coords_kwargs): """ Assign coords to WrfChemDataHandler._data @@ -134,6 +136,7 @@ class BaseWrfChemDataLoader: """ self._data = self._data.assign_coords(coords, **coords_kwargs) + @TimeTrackingWrapper def rechunk_data(self, chunks=None, name_prefix='xarray-', token=None, lock=False): self._data = self._data.chunk(chunks=chunks, name_prefix=name_prefix, token=token, lock=lock) @@ -154,10 +157,12 @@ class BaseWrfChemDataLoader: status = 'FAIL' print(f'{i}: {filenamestr} {status}') + @TimeTrackingWrapper def get_distances(self, lat, lon): dist = haversine_dist(lat1=self._data.XLAT, lon1=self._data.XLONG, lat2=lat, lon2=lon) return dist + @TimeTrackingWrapper def get_bearing(self, lat, lon, points_last=True): bearing = bearing_angle(lat1=lat, lon1=lon, lat2=self._data.XLAT, lon2=self._data.XLONG) if points_last: @@ -165,6 +170,7 @@ class BaseWrfChemDataLoader: else: return bearing + @TimeTrackingWrapper def compute_nearest_icoordinates(self, lat, lon, dim=None): dist = self.get_distances(lat=lat, lon=lon) @@ -173,6 +179,7 @@ class BaseWrfChemDataLoader: else: return dist.argmin(dim) + @TimeTrackingWrapper def _set_dims_as_coords(self): if self._data is None: raise IOError(f'{self.__class__.__name__} can not set dims as coords. Use must use `open_data()` before.') @@ -182,6 +189,7 @@ class BaseWrfChemDataLoader: self._data = data logging.info('set dimensions as coordinates') + @TimeTrackingWrapper def apply_staged_transormation(self, mapping_of_stag2unstag=None): if mapping_of_stag2unstag is None: mapping_of_stag2unstag = {'U': 'U10', 'V': 'V10', 'U10': 'U10', 'V10': 'V10'} @@ -222,6 +230,7 @@ class BaseWrfChemDataLoader: self._data[u_ll_name] = ull self._data[v_ll_name] = vll + @TimeTrackingWrapper def set_interpolated_field(self, staged_field: xr.DataArray, target_field: xr.DataArray, dropped_staged_attrs: List[str] =None, **kwargs): stagger_attr_name = kwargs.pop('stagger_attr_name', 'stagger') @@ -282,6 +291,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): # self._set_geoinfos() logging.debug("SingleGridColumnWrfChemDataLoader Initialised") + @TimeTrackingWrapper def __enter__(self): self.open_data() @@ -301,6 +311,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self.data.close() gc.collect() + @TimeTrackingWrapper def _set_geoinfos(self): # identify nearest coords self._set_nearest_icoords(dim=[self.logical_x_coord_name, self.logical_y_coord_name]) @@ -335,6 +346,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): def geo_infos(self): return self._geo_infos + @TimeTrackingWrapper def _apply_external_coordinates(self): ds_coords = xr.open_dataset(self.external_coords_file, chunks={'south_north': 36, 'west_east': 40}) data = self._data @@ -344,6 +356,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): ds_coords.close() logging.debug('setup external coords') + @TimeTrackingWrapper def _set_coords(self, coords): __set_coords = dict(lat=None, lon=None) if len(coords) != 2: @@ -359,12 +372,14 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): else: raise TypeError(f"`coords' must be a tuple of floats or a dict, but is of type: {type(coords)}") + @TimeTrackingWrapper def get_coordinates(self, as_arrays=False) -> Union[Tuple[np.ndarray, np.ndarray], dict]: if as_arrays: return np.array(self.__coords['lat']), np.array(self.__coords['lon']) else: return self.__coords + @TimeTrackingWrapper def _set_nearest_icoords(self, dim=None): lat, lon = self.get_coordinates(as_arrays=True) self._nearest_icoords = dask.compute(self.compute_nearest_icoordinates(lat, lon, dim))[0] @@ -463,6 +478,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): ) self.__loader = loader + @TimeTrackingWrapper def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None, store_data_locally=False, data_origin: Dict = None, start=None, end=None): @@ -527,7 +543,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): # # self.input_data = inputs # # self.target_data = targets # raise NotImplementedError - + @TimeTrackingWrapper def make_samples(self): self.make_history_window(self.target_dim, self.window_history_size, self.time_dim) if self.var_logical_z_coord_selector is not None: @@ -573,6 +589,14 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): """ return self.observation + @property + def transformation_is_applied(self): + """ + Returns True if transformation is applied. + Checks if all entries in `self._transform' are not None/empty + """ + return all(self._transformation) + # @TimeTrackingWrapper # def setup_samples(self): # """ @@ -602,29 +626,32 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): # _requirements2 = remove_items(inspect.getfullargspec(DataHandlerSingleGridColumn).args, ["self", "station"]) _requirements = DataHandlerWRF.requirements() - def __init__(self, *args, radius=None, sectors=None, sector_dim_name=None, **kwargs): + def __init__(self, *args, radius=None, sectors=None, wind_sector_edge_dim_name=None, **kwargs): if radius is None: radius = 100 # km self.radius = radius if sectors is None: sectors = WindSector.DEFAULT_WIND_SECTORS self.sectors = sectors - if sector_dim_name is None: - sector_dim_name = 'wind_sector' - self.sector_dim_name = sector_dim_name - self.windsector = WindSector(wind_sectors=self.sectors) + if wind_sector_edge_dim_name is None: + wind_sector_edge_dim_name = 'edges' + self.wind_sector_edge_dim_name = wind_sector_edge_dim_name + self.windsector = WindSector(wind_sectors=self.sectors, + wind_sector_edge_dim_name=self.wind_sector_edge_dim_name) self._added_vars = [] + self.wind_dir_name = None super().__init__(*args, **kwargs) + @TimeTrackingWrapper def extract_data_from_loader(self, loader): wind_dir_name = self._get_wind_dir_var_name(loader) full_data = loader.data.isel(loader.get_nearest_icoords()).squeeze() data = full_data[self.variables] sec_data = self.windsector.get_sect_of_value(full_data[wind_dir_name]) + self.sec_data = sec_data if wind_dir_name not in data: data[wind_dir_name] = full_data[wind_dir_name] self._added_vars.append(to_list(wind_dir_name)) - # data[self.sector_dim_name] = sec_data return data def _get_wind_dir_var_name(self, loader, wdir_name3d='wdirll', wdir_name2d='wdir10ll'): @@ -645,15 +672,49 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): raise ValueError(f"invalid wind var name passed to `_get_wind_dir_var_name'") assert var_name in itertools.chain(*itertools.chain(*loader.vars_to_rotate)) + if self.wind_dir_name is None: + self.wind_dir_name = var_name return var_name + def _get_left_and_right_wind_sector_edges(self, **kwargs): + ws_edges = self.windsector.wind_sector_edges_data(**kwargs) + ws_edges = ws_edges.expand_dims(dim={self.target_dim: to_list(self.wind_dir_name)}) + return ws_edges + + def get_transformation_opts(self, base=0): + pos_of_trafo = self.get_transformation_base(base) + trafo_opts = self._transformation[pos_of_trafo] + return trafo_opts + + def apply_transformation_on_wind_sector_edges(self, ws_edges): + ws_trafo_opts = self.get_transformation_opts() + ws_trafo, _ = self.transform(ws_edges, dim=self.aggregation_dim, inverse=False, + opts=ws_trafo_opts, transformation_dim=self.target_dim) + return ws_trafo + def modify_history(self): - ws_data = self.windsector.wind_sector_edges_data() - # t, topts = self.transform(ws_data, opts=self._transformation[0]['wdir10ll'], - # transformation_dim='index') - t = self.transform(self.windsector.wind_sector_edges_data()) + if self.transformation_is_applied: + ws_edges = self.get_applied_transdormation_on_wind_sector_edges() + wind_dir_of_interest = self.compute_wind_dir_of_interest() + sector_allocation = self.windsector.get_sect_of_value(value=wind_dir_of_interest, external_edges=ws_edges) + existing_sectors = np.unique(sector_allocation.data) + with self.loader as loader: + pass + #circular_data = loader.data[self.variables].where(loader.geo_infos.dist.squeeze() <= self.radius) + + return self.history + def compute_wind_dir_of_interest(self): + wind_dir_of_intrest = self.history.sel({self.target_dim: self.wind_dir_name, self.window_dim: 0}) + return wind_dir_of_intrest + + @TimeTrackingWrapper + def get_applied_transdormation_on_wind_sector_edges(self): + ws_edges = self._get_left_and_right_wind_sector_edges(return_as='xr.da', dim=self.wind_sector_edge_dim_name) + ws_edges = self.apply_transformation_on_wind_sector_edges(ws_edges) + return ws_edges + # def set_inputs_and_targets(self): # inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables) + helpers.to_list(self.sector_dim_name)}) # targets = self._data.sel( diff --git a/mlair/helpers/geofunctions.py b/mlair/helpers/geofunctions.py index 1546d2be2651c8d0c862da0a564454a2ecae0d9e..3bf4415f51ab6cafe27a9975ae85e2452ed59c70 100644 --- a/mlair/helpers/geofunctions.py +++ b/mlair/helpers/geofunctions.py @@ -139,11 +139,19 @@ def bearing_angle(lat1: xr_int_float, lon1: xr_int_float, class WindSector: DEFAULT_WIND_SECTORS = ['N', 'NNE', 'NE', 'ENE', 'E', 'ESE', 'SE', 'SSE', 'S', 'SSW', 'SW', 'WSW', 'W', 'WNW', 'NW', 'NNW'] + DEFAULT_WIND_SECTOR_EDGE_NAMES = ["left_edge", "right_edge"] + DEFAULT_WIND_SECTOR_EDGE_DIM_NAME = "edges" - def __init__(self, wind_sectors=None): + def __init__(self, wind_sectors=None, wind_sector_edge_dim_name=None, wind_sector_edge_names=None): if wind_sectors is None: wind_sectors = self.DEFAULT_WIND_SECTORS self._set_wind_sectores(wind_sectors) + if wind_sector_edge_dim_name is None: + wind_sector_edge_dim_name = self.DEFAULT_WIND_SECTOR_EDGE_DIM_NAME + self.wind_sector_edge_dim_name = wind_sector_edge_dim_name + if wind_sector_edge_names is None: + wind_sector_edge_names = self.DEFAULT_WIND_SECTOR_EDGE_NAMES + self.wind_sector_edge_names = wind_sector_edge_names def _set_wind_sectores(self, wind_sectores): # adapted from https://gitlab.version.fz-juelich.de/toar/geolocationservices/-/blob/master/utils/geoutils.py @@ -157,15 +165,17 @@ class WindSector: self.edges_per_sector = edges_per_sector self.wind_sectore_edges = wind_sector_edges - def wind_sector_edges_data(self, return_as="xr"): + def wind_sector_edges_data(self, return_as="xr.ds", **kwargs): data = pd.DataFrame.from_dict(self.wind_sectore_edges, orient="index", - columns=["left_edge", "right_edge"]) + columns=self.wind_sector_edge_names) if return_as == "pd": return data - elif return_as == "xr": + elif return_as == "xr.ds": return data.to_xarray() + elif return_as == 'xr.da': + return data.to_xarray().to_array(**kwargs) else: - raise ValueError(f"`return_as' must be 'pd' or 'xr'. But is {return_as}") + raise ValueError(f"`return_as' must be 'pd', 'xr.ds' or 'xr.da'. But is {return_as}") @staticmethod def _is_value_in_sector(value, left_edge, right_edge) -> bool: @@ -180,14 +190,21 @@ class WindSector: else: raise ValueError(f"`left_edge' and `right_edge' must not be the same.") - def is_in_sector(self, sector: str, value) -> bool: - left_edge, right_edge = self.wind_sectore_edges[sector] + def is_in_sector(self, sector: str, value, **kwargs) -> bool: + edges = kwargs.pop('external_edges', self.wind_sectore_edges) + if isinstance(edges, dict): + left_edge, right_edge = edges[sector] + elif isinstance(edges, xr.DataArray): + left_edge = edges.sel({'index': sector, self.wind_sector_edge_dim_name: self.wind_sector_edge_names[0]}) + right_edge = edges.sel({'index': sector, self.wind_sector_edge_dim_name: self.wind_sector_edge_names[1]}) + else: + raise TypeError(f"`edges must be a dict or xr.DataArray but is of type {type(edges)}'") return self._is_value_in_sector(value, left_edge, right_edge) - def get_sect_of_value(self, value): + def get_sect_of_value(self, value, **kwargs): sec_collector = xr.ones_like(value) for sect in self.wind_sectore_edges.keys(): - sec_collector = xr.where(self.is_in_sector(sect, value), sect, sec_collector) + sec_collector = xr.where(self.is_in_sector(sect, value, **kwargs), sect, sec_collector) return sec_collector