diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 176fdd8f70f884aca39301a9feec051e9c47234e..e69e63d33e1e6fd785557408cb403d138f89c45a 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -121,14 +121,12 @@ class BaseWrfChemDataLoader: def dataset_search_str(self): return os.path.join(self.data_path, self.common_file_starter + '*') - @TimeTrackingWrapper def open_data(self): logging.debug(f'open data: {self.dataset_search_str}') data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, parallel=True) self._data = data - @TimeTrackingWrapper def assign_coords(self, coords, **coords_kwargs): """ Assign coords to WrfChemDataHandler._data @@ -136,7 +134,6 @@ class BaseWrfChemDataLoader: """ self._data = self._data.assign_coords(coords, **coords_kwargs) - @TimeTrackingWrapper def rechunk_data(self, chunks=None, name_prefix='xarray-', token=None, lock=False): self._data = self._data.chunk(chunks=chunks, name_prefix=name_prefix, token=token, lock=lock) @@ -157,12 +154,10 @@ class BaseWrfChemDataLoader: status = 'FAIL' print(f'{i}: {filenamestr} {status}') - @TimeTrackingWrapper def get_distances(self, lat, lon): dist = haversine_dist(lat1=self._data.XLAT, lon1=self._data.XLONG, lat2=lat, lon2=lon) return dist - @TimeTrackingWrapper def get_bearing(self, lat, lon, points_last=True): bearing = bearing_angle(lat1=lat, lon1=lon, lat2=self._data.XLAT, lon2=self._data.XLONG) if points_last: @@ -170,7 +165,6 @@ class BaseWrfChemDataLoader: else: return bearing - @TimeTrackingWrapper def compute_nearest_icoordinates(self, lat, lon, dim=None): dist = self.get_distances(lat=lat, lon=lon) @@ -179,7 +173,6 @@ class BaseWrfChemDataLoader: else: return dist.argmin(dim) - @TimeTrackingWrapper def _set_dims_as_coords(self): if self._data is None: raise IOError(f'{self.__class__.__name__} can not set dims as coords. Use must use `open_data()` before.') @@ -189,7 +182,6 @@ class BaseWrfChemDataLoader: self._data = data logging.info('set dimensions as coordinates') - @TimeTrackingWrapper def apply_staged_transormation(self, mapping_of_stag2unstag=None): if mapping_of_stag2unstag is None: mapping_of_stag2unstag = {'U': 'U10', 'V': 'V10', 'U10': 'U10', 'V10': 'V10'} @@ -230,7 +222,6 @@ class BaseWrfChemDataLoader: self._data[u_ll_name] = ull self._data[v_ll_name] = vll - @TimeTrackingWrapper def set_interpolated_field(self, staged_field: xr.DataArray, target_field: xr.DataArray, dropped_staged_attrs: List[str] =None, **kwargs): stagger_attr_name = kwargs.pop('stagger_attr_name', 'stagger') @@ -291,7 +282,6 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): # self._set_geoinfos() logging.debug("SingleGridColumnWrfChemDataLoader Initialised") - @TimeTrackingWrapper def __enter__(self): self.open_data() @@ -311,7 +301,6 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self.data.close() gc.collect() - @TimeTrackingWrapper def _set_geoinfos(self): # identify nearest coords self._set_nearest_icoords(dim=[self.logical_x_coord_name, self.logical_y_coord_name]) @@ -346,7 +335,6 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): def geo_infos(self): return self._geo_infos - @TimeTrackingWrapper def _apply_external_coordinates(self): ds_coords = xr.open_dataset(self.external_coords_file, chunks={'south_north': 36, 'west_east': 40}) data = self._data @@ -356,7 +344,6 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): ds_coords.close() logging.debug('setup external coords') - @TimeTrackingWrapper def _set_coords(self, coords): __set_coords = dict(lat=None, lon=None) if len(coords) != 2: @@ -372,14 +359,12 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): else: raise TypeError(f"`coords' must be a tuple of floats or a dict, but is of type: {type(coords)}") - @TimeTrackingWrapper def get_coordinates(self, as_arrays=False) -> Union[Tuple[np.ndarray, np.ndarray], dict]: if as_arrays: return np.array(self.__coords['lat']), np.array(self.__coords['lon']) else: return self.__coords - @TimeTrackingWrapper def _set_nearest_icoords(self, dim=None): lat, lon = self.get_coordinates(as_arrays=True) self._nearest_icoords = dask.compute(self.compute_nearest_icoordinates(lat, lon, dim))[0] @@ -643,7 +628,6 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): self.wind_dir_name = None super().__init__(*args, **kwargs) - @TimeTrackingWrapper def extract_data_from_loader(self, loader): wind_dir_name = self._get_wind_dir_var_name(loader) full_data = loader.data.isel(loader.get_nearest_icoords()).squeeze() @@ -706,12 +690,27 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): # sector_history = sector_history.assign_coords({self.target_dim: sector_history_var_names}) grid_data = self.preselect_and_transform_neighbouring_data_based_on_radius(loader) - + # grid_data = grid_data.expand_dims(self.iter_dim, -1).assign_coords( + # {self.iter_dim: self.history.coords[self.iter_dim].values}) + # sec_data_history_var_names = [f"{var}_sect" for var in self.history.coords[self.target_dim].values] + logging.info("preselect_and_transform_neighbouring_data_based_on_radius(loader)") for sect in existing_sectors: # select data in wind sector sec_data = self.get_section_data_from_circle(grid_data, loader, sect) sec_data = self.apply_aggregation_method_on_sector_data(sec_data, loader) - + sec_data_history = self._make_and_return_history_window(dim_name_of_shift=self.time_dim, + window=self.window_history_size, + data=sec_data.to_array(self.target_dim) + ) + sec_data_history = sec_data_history.broadcast_like(self.history) + sec_data_history = sec_data_history.transpose(*self.history.dims) + sector_history = xr.where(sector_allocation.squeeze() == sect, + sec_data_history.sel({self.time_dim: sector_history[self.time_dim]}), + sector_history * 1.) + + sector_history = sector_history.assign_coords({self.target_dim: sector_history_var_names}) + sector_history = sector_history.compute() + combined_history = xr.concat([self.history, sector_history], dim=self.target_dim) # loader.data.T2.where(loader.geo_infos.dist.sel({'points': 0}).drop('points') <= self.radius).where( # self.windsector.is_in_sector(sect, loader.geo_infos.bearing)) @@ -722,7 +721,9 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): # self.windsector.is_in_sector(sect, # loader.geo_infos.bearing.drop('points').squeeze())) - return self.history + return combined_history + else: + return self.history def get_section_data_from_circle(self, grid_data, loader, sect): sec_data = grid_data.where( @@ -761,7 +762,6 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): wind_dir_of_intrest = self.history.sel({self.target_dim: self.wind_dir_name, self.window_dim: 0}) return wind_dir_of_intrest - @TimeTrackingWrapper def get_applied_transdormation_on_wind_sector_edges(self): ws_edges = self._get_left_and_right_wind_sector_edges(return_as='xr.da', dim=self.wind_sector_edge_dim_name) ws_edges = self.apply_transformation_on_data(ws_edges)