diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index af4660c13bf6cc98ef82b92baf184f8585b606f4..ee6ecc849bfeb961059b0b4c343042e001cbd61b 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -11,6 +11,11 @@ from dask.diagnostics import ProgressBar from tzwhere import tzwhere from toarstats import toarstats +import hashlib +from functools import reduce, partial +import dill + + import dask import inspect import os @@ -18,6 +23,7 @@ import gc from mlair.helpers.geofunctions import haversine_dist, bearing_angle, WindSector, VectorRotateLambertConformal2latlon from mlair.helpers.helpers import convert2xrda, remove_items, to_list from mlair.helpers import TimeTrackingWrapper, TimeTracking +from mlair.configuration import check_path_and_create from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation @@ -71,7 +77,6 @@ class BaseWrfChemDataLoader: vars_to_rotate: Tuple[Tuple[Tuple[str, str], Tuple[str, str]]] = DEFAULT_VARS_TO_ROTATE, staged_dimension_mapping=None, stag_ending='_stag', date_format_of_nc_file=None, - ): """ Initialisze data loader @@ -396,12 +401,20 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): DEFAULT_ITER_DIM = "points" DEFAULT_WINDOW_DIM = "window" + _hash = ["data_path", "external_coords_file", "time_dim_name", + "rechunk_values", "variables", "z_coord_selector", "date_format_of_nc_file", + "wind_sectors", "wind_sector_edge_dim_name", "statistics_per_var", + "aggregation_sampling", "time_zone", + ] + def __init__(self, coords: Tuple[float_np_xr, float_np_xr], target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_lead_time=DEFAULT_WINDOW_LEAD_TIME, external_coords_file: str = None, - wind_sectors=None, wind_sector_edge_dim_name=None, **kwargs): + wind_sectors=None, wind_sector_edge_dim_name=None, statistics_per_var: Dict = None, + aggregation_sampling: str = None, time_zone: str = None, station: str = None, + lazy_preprocessing: bool = False, **kwargs): """ @@ -419,6 +432,10 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): :type external_coords_file: :param wind_sectors: :type wind_sectors: + :param statistics_per_var: Dict containing the (TOAR-) statistics as value for each variable (key) + :param aggregation_sampling: sampling period for statistics (e.g. "daily", "seasonal" etc. See torstats for more details) + :param time_zone: + :param: station :param kwargs: :type kwargs: """ @@ -433,6 +450,17 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self.external_coords_file = external_coords_file self.wind_sectors = wind_sectors self.wind_sector_edge_dim_name = wind_sector_edge_dim_name + self.statistics_per_var = statistics_per_var + self.aggregation_sampling = aggregation_sampling + self.time_zone = time_zone + self.station = station + + self.lazy = lazy_preprocessing + self.lazy_path = None + if self.lazy is True: + self.lazy_path = os.path.join(self.data_path, "lazy_data", self.__class__.__name__) + check_path_and_create(self.lazy_path) + logging.debug("SingleGridColumnWrfChemDataLoader Initialised") @@ -452,11 +480,64 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self._set_dims_as_coords() if self.external_coords_file is not None: self._apply_external_coordinates() + self.apply_staged_transormation() - # self.rechunk_data(self.rechunk_values) self._set_geoinfos() + + if self.lazy is False: + self.reset_data_by_other(self.apply_toarstats()) + else: + self.load_lazy() + self.store_lazy() + + # self.rechunk_data(self.rechunk_values) return self + def reset_data_by_other(self, other: xr.Dataset): + attrs = self._data.attrs + self._data = other + # for var in other: + # self._data[var] = other[var] + + def store_lazy(self): + hash = self._get_hash() + filename = os.path.join(self.lazy_path, hash + ".nc") + if not os.path.exists(filename): + self._data.to_netcdf(filename, format="NETCDF4", engine="netcdf4") + # dill.dump(self._create_lazy_data(), file=open(filename, "wb")) + + # def _create_lazy_data(self): + # return [self._data ]#, self.meta, self.input_data, self.target_data] + + def load_lazy(self): + hash = self._get_hash() + filename = os.path.join(self.lazy_path, hash + ".nc") + try: + self._data = xr.open_mfdataset(filename) + # with open(filename, "rb") as pickle_file: + # lazy_data = dill.load(pickle_file) + # self._extract_lazy(lazy_data) + logging.debug(f"{self.station[0]}: used lazy data") + except FileNotFoundError: + logging.debug(f"{self.station[0]}: could not use lazy data") + self.reset_data_by_other(self.apply_toarstats()) + except OSError: + logging.debug(f"{self.station[0]}: could not use lazy data") + self.reset_data_by_other(self.apply_toarstats()) + + def _extract_lazy(self, lazy_data): + # _data, self.meta, _input_data, _target_data = lazy_data + # f_prep = partial(self._slice_prep, start=self.start, end=self.end) + self._data = lazy_data[0] + # self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data])) + + def _hash_list(self): + return sorted(list(set(self._hash))) + + def _get_hash(self): + hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode() + return hashlib.md5(hash).hexdigest() + def __exit__(self, exc_type, exc_val, exc_tb): self.data.close() gc.collect() @@ -601,6 +682,108 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): else: return {k: list(v.values) for k, v in self._nearest_coords.items()} + def apply_toarstats(self): + dh = self.prepare_and_apply_toarstats(self.data, target_sampling=self.aggregation_sampling) + return dh + + @staticmethod + def _toarstats_wrapper(data, sampling, statistics, metadata, seasons=None, crops=None, + data_capture=None): + + return toarstats(sampling, statistics, data, metadata, seasons, crops, data_capture) + + + @TimeTrackingWrapper + def prepare_and_apply_toarstats(self, data, target_sampling="daily"): + meta = pd.DataFrame({self.physical_x_coord_name: self.get_coordinates()['lon'].tolist(), + self.physical_y_coord_name: self.get_coordinates()['lat'].tolist()}, + index=self.station) + local_time_zone = self.get_local_time_zone_from_lat_lon(meta=meta) + collector = [] + for var in self.statistics_per_var.keys(): + collector.append(self.__toarstats_aggregation(data, local_time_zone, meta, target_sampling, var)) + sampling_data = xr.merge(collector).dropna(self.physical_t_coord_name) + # sampling_data.attrs = data.attrs + missing_squeezed_coords = data.coords._names - sampling_data.coords._names + for coord in missing_squeezed_coords: + sampling_data.coords[coord] = data.coords[coord] + + return sampling_data + + def __toarstats_aggregation(self, data, local_time_zone, meta, target_sampling, var): + with TimeTracking(name=f"{self.station}: apply toarstats `{self.statistics_per_var[var]}({var})`"): + spatial_dims = list(remove_items(data[var].dims, self.physical_t_coord_name)) + df = data[var].to_dataframe()[[var]].reset_index(level=spatial_dims) + df = df[[var] + spatial_dims] + df = self.set_external_time_zone_and_convert_to_local_time_zone(df, local_time_zone) + df = df.groupby(spatial_dims) + df = df.apply(self._toarstats_wrapper, sampling=target_sampling, statistics=self.statistics_per_var[var], + metadata=(meta[self.physical_y_coord_name], meta[self.physical_x_coord_name])) + df.columns = [var] + df.index.set_names(df.index.names[:len(spatial_dims)] + [self.physical_t_coord_name], inplace=True) + # df = df.to_xarray().to_array(self.target_dim) + df = df.to_xarray() + df = df.chunk({self.logical_x_coord_name: -1}) + # collector.append(df) + return df + + def set_external_time_zone_and_convert_to_local_time_zone(self, data, local_time_zone): + """ + + :param data: + :type data: + :param local_time_zone: + :type local_time_zone: + :return: + :rtype: + """ + hdata = data + # hdata = data.squeeze().to_pandas() + hdata.index = self.set_time_zone(hdata.index) + hdata.index = hdata.index.tz_convert(local_time_zone) + logging.debug(f"Set local time zone to: {local_time_zone}") + return hdata + + def get_local_time_zone_from_lat_lon(self, lat=None, lon=None, meta=None): + """ + Retuns name of time zone for given lat lon coordinates. + + Method also accepts a meta data pd.DataFrame where lat and lon are extracted with the loader's physical x and y + coord names + + :param lat: + :type lat: + :param lon: + :type lon: + :param meta: + :type meta: pd.DataFrame + :return: + :rtype: str + """ + if (lat is None and lon is None) and (meta is not None): + lat = meta[self.physical_y_coord_name] + lon = meta[self.physical_x_coord_name] + + tz_where = tzwhere.tzwhere() + local_time_zone = tz_where.tzNameAt(latitude=lat, longitude=lon) + logging.debug(f"Detect local time zone '{local_time_zone}' based on lat={lat}, lon={lon}") + return local_time_zone + + def set_time_zone(self, time_index): + """ + Sets time zone information on a given index + + :param time_index: + :type time_index: + :return: + :rtype: + """ + + dti = pd.to_datetime(time_index) + dti = dti.tz_localize(self.time_zone) + logging.debug(f"Set external time zone for {self.station} to: {self.time_zone}") + return dti + class DataHandlerSingleGridColumn(DataHandlerSingleStation): _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) @@ -709,7 +892,12 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): # end_time=self.end, date_format_of_nc_file=self.date_format_of_nc_file, wind_sectors=helper_wind_sectors, - wind_sector_edge_dim_name=helper_wind_sector_edge_dim_name , + wind_sector_edge_dim_name=helper_wind_sector_edge_dim_name, + statistics_per_var=self.statistics_per_var, + aggregation_sampling=self.input_output_sampling4toarstats[1], + time_zone=self.time_zone, + station=self.station, + lazy_preprocessing=True, ) self.__loader = loader @@ -743,92 +931,92 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.loader.physical_y_coord_name: self.loader.get_coordinates()['lat'].tolist()} meta = pd.DataFrame(_meta, index=station) - if isinstance(self.input_output_sampling4toarstats, tuple) and len(self.input_output_sampling4toarstats) == 2: - if self.var_logical_z_coord_selector != 0: - raise NotImplementedError( - f"Method `apply_toarstats` is not implemented for var_logical_z_coord_selector != 0: " - f"Is {self.var_logical_z_coord_selector}") - data = self.prepare_and_apply_toarstats(data, meta, self.input_output_sampling4toarstats[1]) + # if isinstance(self.input_output_sampling4toarstats, tuple) and len(self.input_output_sampling4toarstats) == 2: + # if self.var_logical_z_coord_selector != 0: + # raise NotImplementedError( + # f"Method `apply_toarstats` is not implemented for var_logical_z_coord_selector != 0: " + # f"Is {self.var_logical_z_coord_selector}") + # data = self.prepare_and_apply_toarstats(data, meta, self.input_output_sampling4toarstats[1]) data = self._slice_prep(data, start=start, end=end) return data, meta - @TimeTrackingWrapper - def prepare_and_apply_toarstats(self, data, meta, target_sampling="daily"): - local_time_zone = self.get_local_time_zone_from_lat_lon(meta=meta) - hdata = self.set_external_time_zone_and_convert_to_local_time_zone(data, local_time_zone) - hsampling_data = [] - for i, var in enumerate(hdata.columns): - hdf = toarstats(target_sampling, self.statistics_per_var[var], hdata[var], - (meta[self.loader.physical_y_coord_name], meta[self.loader.physical_x_coord_name])) - hsampling_data.append(xr.DataArray(hdf, coords=[hdf.index, [var]], - dims=[self.loader.physical_t_coord_name, self.target_dim])) - sampling_data = xr.concat(hsampling_data, dim=self.target_dim) - sampling_data = sampling_data.broadcast_like(data, exclude=self.loader.physical_t_coord_name).dropna( - self.loader.physical_t_coord_name) - sampling_data.attrs = data.attrs - missing_squeezed_coords = data.coords._names - sampling_data.coords._names - for coord in missing_squeezed_coords: - sampling_data.coords[coord] = data.coords[coord] - - return self._force_dask_computation(sampling_data) - - def set_external_time_zone_and_convert_to_local_time_zone(self, data, local_time_zone): - """ - - :param data: - :type data: - :param local_time_zone: - :type local_time_zone: - :return: - :rtype: - """ - hdata = data.squeeze().to_pandas() - hdata.index = self.set_time_zone(hdata.index) - hdata.index = hdata.index.tz_convert(local_time_zone) - logging.debug(f"Set local time zone for {self.station} to: {local_time_zone}") - return hdata - - def get_local_time_zone_from_lat_lon(self, lat=None, lon=None, meta=None): - """ - Retuns name of time zone for given lat lon coordinates. - - Method also accepts a meta data pd.DataFrame where lat and lon are extracted with the loader's physical x and y - coord names - - :param lat: - :type lat: - :param lon: - :type lon: - :param meta: - :type meta: pd.DataFrame - :return: - :rtype: str - """ - if (lat is None and lon is None) and (meta is not None): - lat = meta[self.loader.physical_y_coord_name] - lon = meta[self.loader.physical_x_coord_name] - - - tz_where = tzwhere.tzwhere() - local_time_zone = tz_where.tzNameAt(latitude=lat, longitude=lon) - logging.debug(f"Detect local time zone '{local_time_zone}' based on lat={lat}, lon={lon}") - return local_time_zone - - def set_time_zone(self, time_index): - """ - Sets time zone information on a given index - - :param time_index: - :type time_index: - :return: - :rtype: - """ - - dti = pd.to_datetime(time_index) - dti = dti.tz_localize(self.time_zone) - logging.debug(f"Set external time zone for {self.station} to: {self.time_zone}") - return dti + # @TimeTrackingWrapper + # def prepare_and_apply_toarstats(self, data, meta, target_sampling="daily"): + # local_time_zone = self.get_local_time_zone_from_lat_lon(meta=meta) + # hdata = self.set_external_time_zone_and_convert_to_local_time_zone(data, local_time_zone) + # hsampling_data = [] + # for i, var in enumerate(hdata.columns): + # hdf = toarstats(target_sampling, self.statistics_per_var[var], hdata[var], + # (meta[self.loader.physical_y_coord_name], meta[self.loader.physical_x_coord_name])) + # hsampling_data.append(xr.DataArray(hdf, coords=[hdf.index, [var]], + # dims=[self.loader.physical_t_coord_name, self.target_dim])) + # sampling_data = xr.concat(hsampling_data, dim=self.target_dim) + # sampling_data = sampling_data.broadcast_like(data, exclude=self.loader.physical_t_coord_name).dropna( + # self.loader.physical_t_coord_name) + # sampling_data.attrs = data.attrs + # missing_squeezed_coords = data.coords._names - sampling_data.coords._names + # for coord in missing_squeezed_coords: + # sampling_data.coords[coord] = data.coords[coord] + # + # return self._force_dask_computation(sampling_data) + # + # def set_external_time_zone_and_convert_to_local_time_zone(self, data, local_time_zone): + # """ + # + # :param data: + # :type data: + # :param local_time_zone: + # :type local_time_zone: + # :return: + # :rtype: + # """ + # hdata = data.squeeze().to_pandas() + # hdata.index = self.set_time_zone(hdata.index) + # hdata.index = hdata.index.tz_convert(local_time_zone) + # logging.debug(f"Set local time zone for {self.station} to: {local_time_zone}") + # return hdata + # + # def get_local_time_zone_from_lat_lon(self, lat=None, lon=None, meta=None): + # """ + # Retuns name of time zone for given lat lon coordinates. + # + # Method also accepts a meta data pd.DataFrame where lat and lon are extracted with the loader's physical x and y + # coord names + # + # :param lat: + # :type lat: + # :param lon: + # :type lon: + # :param meta: + # :type meta: pd.DataFrame + # :return: + # :rtype: str + # """ + # if (lat is None and lon is None) and (meta is not None): + # lat = meta[self.loader.physical_y_coord_name] + # lon = meta[self.loader.physical_x_coord_name] + # + # + # tz_where = tzwhere.tzwhere() + # local_time_zone = tz_where.tzNameAt(latitude=lat, longitude=lon) + # logging.debug(f"Detect local time zone '{local_time_zone}' based on lat={lat}, lon={lon}") + # return local_time_zone + # + # def set_time_zone(self, time_index): + # """ + # Sets time zone information on a given index + # + # :param time_index: + # :type time_index: + # :return: + # :rtype: + # """ + # + # dti = pd.to_datetime(time_index) + # dti = dti.tz_localize(self.time_zone) + # logging.debug(f"Set external time zone for {self.station} to: {self.time_zone}") + # return dti @staticmethod def _extract_largest_coord_extractor(var_extarctor, target_extractor) -> Union[List, None]: @@ -923,9 +1111,8 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): """ return all(self._transformation) - @staticmethod - def interpolate(data, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, - **kwargs): + def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None, + use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs): """ Interpolate values according to different methods. @@ -962,6 +1149,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): :return: xarray.DataArray """ + # data = self.create_full_time_dim(data, dim, sampling) return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)