diff --git a/src/data_handling/__init__.py b/src/data_handling/__init__.py
index 198147842dc3fe2a606d71bbbeed479148824124..cb5aa5db0f29cf51d32ed54e810fa9b363d80cc6 100644
--- a/src/data_handling/__init__.py
+++ b/src/data_handling/__init__.py
@@ -10,6 +10,6 @@ __date__ = '2020-04-17'
 
 from .bootstraps import BootStraps
-from .data_preparation_join import DataPrep
+from .data_preparation_join import DataPrepJoin
 from .data_generator import DataGenerator
 from .data_distributor import Distributor
diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 672ef8a1355441f2481bbf41a40cc951f334a30f..7b83b56f8c7f6b79f0f598d3f1b3d33c34df53bf 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -13,7 +13,7 @@ import keras
 import xarray as xr
 
 from src import helpers
-from src.data_handling.data_preparation_join import DataPrep
+from src.data_handling.data_preparation_join import DataPrepJoin
 from src.helpers.join import EmptyQueryResult
 
 number = Union[float, int]
@@ -210,8 +210,8 @@ class DataGenerator(keras.utils.Sequence):
         std = None
         for station in self.stations:
             try:
-                data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                                **self.kwargs)
+                data = DataPrepJoin(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                                    **self.kwargs)
                 chunks = (1, 100, data.data.shape[2])
                 tmp.append(da.from_array(data.data.data, chunks=chunks))
             except EmptyQueryResult:
@@ -249,8 +249,8 @@ class DataGenerator(keras.utils.Sequence):
         std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
         for station in self.stations:
             try:
-                data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                                **self.kwargs)
+                data = DataPrepJoin(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                                    **self.kwargs)
                 data.transform("datetime", method=method)
                 mean = mean.combine_first(data.mean)
                 std = std.combine_first(data.std)
@@ -260,7 +260,7 @@ class DataGenerator(keras.utils.Sequence):
         return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None
 
     def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True,
-                           save_local_tmp_storage: bool = True) -> DataPrep:
+                           save_local_tmp_storage: bool = True) -> DataPrepJoin:
         """
         Create DataPrep object and preprocess data for given key.
 
@@ -288,8 +288,8 @@ class DataGenerator(keras.utils.Sequence):
             data = self._load_pickle_data(station, self.variables)
         except FileNotFoundError:
             logging.debug(f"load not pickle data for {station}")
-            data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                            **self.kwargs)
+            data = DataPrepJoin(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                                **self.kwargs)
             if self.transformation is not None:
                 data.transform("datetime", **helpers.remove_items(self.transformation, "scope"))
             data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index e85d8a3ac732a2ce70a715cc7d6e6e21eee6b32b..366cce7629c3c4070d05f0b91e3fbbf5d556184a 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -78,7 +78,7 @@ class AbstractDataPrep(object):
         else:
             raise NotImplementedError("Either select hourly data or provide statistics_per_var.")
 
-    def load_data(self):
+    def load_data(self, source_name=""):
         """
         Load data and meta data either from local disk (preferred) or download new data by using a custom download method.
 
@@ -86,31 +86,33 @@ class AbstractDataPrep(object):
         cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
         set, it is assumed, that data should be saved locally.
         """
+        source_name = source_name if len(source_name) == 0 else f" from {source_name}"
         check_path_and_create(self.path)
         file_name = self._set_file_name()
         meta_file = self._set_meta_file_name()
         if self.kwargs.get('overwrite_local_data', False):
-            logging.debug(f"overwrite_local_data is true, therefore reload {file_name} from JOIN")
+            logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}")
             if os.path.exists(file_name):
                 os.remove(file_name)
             if os.path.exists(meta_file):
                 os.remove(meta_file)
-            self.download_data(file_name, meta_file)
-            logging.debug("loaded new data from JOIN")
+            data, self.meta = self.download_data(file_name, meta_file)
+            logging.debug(f"loaded new data{source_name}")
         else:
             try:
                 logging.debug(f"try to load local data from: {file_name}")
                 data = xr.open_dataarray(file_name)
                 self.meta = pd.read_csv(meta_file, index_col=0)
+                self.check_station_meta()
                 logging.debug("loading finished")
             except FileNotFoundError as e:
                 logging.debug(e)
-                logging.debug("load new data from JOIN")
+                logging.debug(f"load new data{source_name}")
                 data, self.meta = self.download_data(file_name, meta_file)
                 logging.debug("loading finished")
-            # create slices and check for negative concentration.
-            data = self._slice_prep(data)
-            self.data = self.check_for_negative_concentrations(data)
+        # create slices and check for negative concentration.
+        data = self._slice_prep(data)
+        self.data = self.check_for_negative_concentrations(data)
 
     def download_data(self, file_name, meta_file) -> [xr.DataArray, pd.DataFrame]:
         """
@@ -121,6 +123,14 @@ class AbstractDataPrep(object):
         """
         raise NotImplementedError
 
+    def check_station_meta(self):
+        """
+        Placeholder function to implement some additional station meta data check if desired.
+
+        Ideally, this method should raise a FileNotFoundError if a value mismatch to load fresh data from a source.
+        """
+        pass
+
     def _set_file_name(self):
         all_vars = sorted(self.statistics_per_var.keys())
         return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc")
diff --git a/src/data_handling/data_preparation_join.py b/src/data_handling/data_preparation_join.py
index 4313891131ddbc4699de6de499c1c3512e2a21dd..1c6593d65b5bdb4484bdf468c176c1becfba3981 100644
--- a/src/data_handling/data_preparation_join.py
+++ b/src/data_handling/data_preparation_join.py
@@ -4,6 +4,7 @@ __author__ = 'Felix Kleinert, Lukas Leufen'
 __date__ = '2019-10-16'
 
 import datetime as dt
+import inspect
 import logging
 import os
 from functools import reduce
@@ -16,6 +17,7 @@ import xarray as xr
 from src.configuration import check_path_and_create
 from src import helpers
 from src.helpers import join, statistics
+from src.data_handling.data_preparation import AbstractDataPrep
 
 # define a more general date type for type hinting
 date = Union[dt.date, dt.datetime]
@@ -25,7 +27,7 @@ num_or_list = Union[number, List[number]]
 data_or_none = Union[xr.DataArray, None]
 
 
-class DataPrep(object):
+class DataPrepJoin(AbstractDataPrep):
     """
     This class prepares data to be used in neural networks.
 
@@ -57,65 +59,11 @@ class DataPrep(object):
 
     def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str],
                  station_type: str = None, **kwargs):
-        """Construct instance."""
-        self.path = os.path.abspath(path)
         self.network = network
-        self.station = helpers.to_list(station)
-        self.variables = variables
         self.station_type = station_type
-        self.mean: data_or_none = None
-        self.std: data_or_none = None
-        self.history: data_or_none = None
-        self.label: data_or_none = None
-        self.observation: data_or_none = None
-        self.extremes_history: data_or_none = None
-        self.extremes_label: data_or_none = None
-        self.kwargs = kwargs
-        self.data = None
-        self.meta = None
-        self._transform_method = None
-        self.statistics_per_var = kwargs.get("statistics_per_var", None)
-        self.sampling = kwargs.get("sampling", "daily")
-        if self.statistics_per_var is not None or self.sampling == "hourly":
-            self.load_data()
-        else:
-            raise NotImplementedError("Either select hourly data or provide statistics_per_var.")
-
-    def load_data(self):
-        """
-        Load data and meta data either from local disk (preferred) or download new data from TOAR database.
-
-        Data is either downloaded, if no local data is available or parameter overwrite_local_data is true. In both
-        cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not
-        set, it is assumed, that data should be saved locally.
-        """
-        check_path_and_create(self.path)
-        file_name = self._set_file_name()
-        meta_file = self._set_meta_file_name()
-        if self.kwargs.get('overwrite_local_data', False):
-            logging.debug(f"overwrite_local_data is true, therefore reload {file_name} from JOIN")
-            if os.path.exists(file_name):
-                os.remove(file_name)
-            if os.path.exists(meta_file):
-                os.remove(meta_file)
-            self.download_data(file_name, meta_file)
-            logging.debug("loaded new data from JOIN")
-        else:
-            try:
-                logging.debug(f"try to load local data from: {file_name}")
-                data = xr.open_dataarray(file_name)
-                self.meta = pd.read_csv(meta_file, index_col=0)
-                if self.station_type is not None:
-                    self.check_station_meta()
-                logging.debug("loading finished")
-            except FileNotFoundError as e:
-                logging.debug(e)
-                logging.debug("load new data from JOIN")
-                data, self.meta = self.download_data(file_name, meta_file)
-                logging.debug("loading finished")
-        # create slices and check for negative concentration.
-        data = self._slice_prep(data)
-        self.data = self.check_for_negative_concentrations(data)
+        params = helpers.remove_items(inspect.getfullargspec(AbstractDataPrep.__init__).args, "self")
+        kwargs = {**{k: v for k, v in locals().items() if k in params and v is not None}, **kwargs}
+        super().__init__(**kwargs)
 
     def download_data(self, file_name, meta_file):
         """
@@ -133,13 +81,14 @@ class DataPrep(object):
 
         Will raise a FileNotFoundError if the values mismatch.
         """
-        check_dict = {"station_type": self.station_type, "network_name": self.network}
-        for (k, v) in check_dict.items():
-            if self.meta.at[k, self.station[0]] != v:
-                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
-                              f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
-                              f"grapping from web.")
-                raise FileNotFoundError
+        if self.station_type is not None:
+            check_dict = {"station_type": self.station_type, "network_name": self.network}
+            for (k, v) in check_dict.items():
+                if self.meta.at[k, self.station[0]] != v:
+                    logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
+                                  f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
+                                  f"grapping from web.")
+                    raise FileNotFoundError
 
     def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]:
         """
@@ -166,426 +115,12 @@ class DataPrep(object):
             meta.to_csv(meta_file)
         return xarr, meta
 
-    def _set_file_name(self):
-        all_vars = sorted(self.statistics_per_var.keys())
-        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc")
-
-    def _set_meta_file_name(self):
-        all_vars = sorted(self.statistics_per_var.keys())
-        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv")
-
     def __repr__(self):
         """Represent class attributes."""
         return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \
                f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})"
 
-    def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
-                    **kwargs):
-        """
-        Interpolate values according to different methods.
-
-        (Copy paste from dataarray.interpolate_na)
-
-        :param dim:
-            Specifies the dimension along which to interpolate.
-        :param method:
-            {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
-            'polynomial', 'barycentric', 'krog', 'pchip',
-            'spline', 'akima'}, optional
-            String indicating which method to use for interpolation:
-
-            - 'linear': linear interpolation (Default). Additional keyword
-              arguments are passed to ``numpy.interp``
-            - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic',
-              'polynomial': are passed to ``scipy.interpolate.interp1d``. If
-              method=='polynomial', the ``order`` keyword argument must also be
-              provided.
-            - 'barycentric', 'krog', 'pchip', 'spline', and `akima`: use their
-              respective``scipy.interpolate`` classes.
-        :param limit:
-            default None
-            Maximum number of consecutive NaNs to fill. Must be greater than 0
-            or None for no limit.
-        :param use_coordinate:
-            default True
-            Specifies which index to use as the x values in the interpolation
-            formulated as `y = f(x)`. If False, values are treated as if
-            eqaully-spaced along `dim`. If True, the IndexVariable `dim` is
-            used. If use_coordinate is a string, it specifies the name of a
-            coordinate variariable to use as the index.
-        :param kwargs:
-
-        :return: xarray.DataArray
-        """
-        self.data = self.data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate,
-                                             **kwargs)
-
-    @staticmethod
-    def check_inverse_transform_params(mean: data_or_none, std: data_or_none, method: str) -> None:
-        """
-        Support inverse_transformation method.
-
-        Validate if all required statistics are available for given method. E.g. centering requires mean only, whereas
-        normalisation requires mean and standard deviation. Will raise an AttributeError on missing requirements.
-
-        :param mean: data with all mean values
-        :param std: data with all standard deviation values
-        :param method: name of transformation method
-        """
-        msg = ""
-        if method in ['standardise', 'centre'] and mean is None:
-            msg += "mean, "
-        if method == 'standardise' and std is None:
-            msg += "std, "
-        if len(msg) > 0:
-            raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}")
-
-    def inverse_transform(self) -> None:
-        """
-        Perform inverse transformation.
-
-        Will raise an AssertionError, if no transformation was performed before. Checks first, if all required
-        statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by
-        new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate, that the
-        current data is not transformed.
-        """
-
-        def f_inverse(data, mean, std, method_inverse):
-            if method_inverse == 'standardise':
-                return statistics.standardise_inverse(data, mean, std), None, None
-            elif method_inverse == 'centre':
-                return statistics.centre_inverse(data, mean), None, None
-            elif method_inverse == 'normalise':
-                raise NotImplementedError
-            else:
-                raise NotImplementedError
-
-        if self._transform_method is None:
-            raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.")
-        self.check_inverse_transform_params(self.mean, self.std, self._transform_method)
-        self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method)
-        self._transform_method = None
-
-    def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None,
-                  std=None) -> None:
-        """
-        Transform data according to given transformation settings.
-
-        This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0
-        and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale
-        (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This
-        method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the
-        internal transform method, internal mean and internal standard deviation weren't set ('inverse=True').
-
-        :param string/int dim: This param is not used for inverse transformation.
-            | for xarray.DataArray as string: name of dimension which should be standardised
-            | for pandas.DataFrame as int: axis of dimension which should be standardised
-        :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented
-            yet. This param is not used for inverse transformation.
-        :param inverse: Switch between transformation and inverse transformation.
-
-        :return: xarray.DataArrays or pandas.DataFrames:
-            #. mean: Mean of data
-            #. std: Standard deviation of data
-            #. data: Standardised data
-        """
-
-        def f(data):
-            if method == 'standardise':
-                return statistics.standardise(data, dim)
-            elif method == 'centre':
-                return statistics.centre(data, dim)
-            elif method == 'normalise':
-                # use min/max of data or given min/max
-                raise NotImplementedError
-            else:
-                raise NotImplementedError
-
-        def f_apply(data):
-            if method == "standardise":
-                return mean, std, statistics.standardise_apply(data, mean, std)
-            elif method == "centre":
-                return mean, None, statistics.centre_apply(data, mean)
-            else:
-                raise NotImplementedError
-
-        if not inverse:
-            if self._transform_method is not None:
-                raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with "
-                                     f"{self._transform_method}. Please perform inverse transformation of data first.")
-            self.mean, self.std, self.data = locals()["f" if mean is None else "f_apply"](self.data)
-            self._transform_method = method
-        else:
-            self.inverse_transform()
-
-    def get_transformation_information(self, variable: str) -> Tuple[data_or_none, data_or_none, str]:
-        """
-        Extract transformation statistics and method.
-
-        Get mean and standard deviation for given variable and the transformation method if set. If a transformation
-        depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are
-        returned with None as fill value.
-
-        :param variable: Variable for which the information on transformation is requested.
-
-        :return: mean, standard deviation and transformation method
-        """
-        try:
-            mean = self.mean.sel({'variables': variable}).values
-        except AttributeError:
-            mean = None
-        try:
-            std = self.std.sel({'variables': variable}).values
-        except AttributeError:
-            std = None
-        return mean, std, self._transform_method
-
-    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
-        """
-        Create a xr.DataArray containing history data.
-
-        Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted
-        data. This is used to represent history in the data. Results are stored in history attribute.
-
-        :param dim_name_of_inputs: Name of dimension which contains the input variables
-        :param window: number of time steps to look back in history
-            Note: window will be treated as negative value. This should be in agreement with looking back on
-            a time line. Nonetheless positive values are allowed but they are converted to its negative
-            expression
-        :param dim_name_of_shift: Dimension along shift will be applied
-        """
-        window = -abs(window)
-        self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables})
-
-    def shift(self, dim: str, window: int) -> xr.DataArray:
-        """
-        Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).
-
-        :param dim: dimension along shift is applied
-        :param window: number of steps to shift (corresponds to the window length)
-
-        :return: shifted data
-        """
-        start = 1
-        end = 1
-        if window <= 0:
-            start = window
-        else:
-            end = window + 1
-        res = []
-        for w in range(start, end):
-            res.append(self.data.shift({dim: -w}))
-        window_array = self.create_index_array('window', range(start, end))
-        res = xr.concat(res, dim=window_array)
-        return res
-
-    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
-                    window: int) -> None:
-        """
-        Create a xr.DataArray containing labels.
-
-        Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label
-        attribute.
-
-        :param dim_name_of_target: Name of dimension which contains the target variable
-        :param target_var: Name of target variable in 'dimension'
-        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
-        :param window: lead time of label
-        """
-        window = abs(window)
-        self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var})
-
-    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
-        """
-        Create a xr.DataArray containing observations.
-
-        Observations are defined as value of the current time step t. Set observation attribute.
-
-        :param dim_name_of_target: Name of dimension which contains the observation variable
-        :param target_var: Name of observation variable(s) in 'dimension'
-        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
-        """
-        self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var})
-
-    def remove_nan(self, dim: str) -> None:
-        """
-        Remove all NAs slices along dim which contain nans in history, label and observation.
-
-        This is done to present only a full matrix to keras.fit. Update history, label, and observation attribute.
-
-        :param dim: dimension along the remove is performed.
-        """
-        intersect = []
-        if (self.history is not None) and (self.label is not None):
-            non_nan_history = self.history.dropna(dim=dim)
-            non_nan_label = self.label.dropna(dim=dim)
-            non_nan_observation = self.observation.dropna(dim=dim)
-            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values,
-                                                non_nan_observation.coords[dim].values))
-
-        min_length = self.kwargs.get("min_length", 0)
-        if len(intersect) < max(min_length, 1):
-            self.history = None
-            self.label = None
-            self.observation = None
-        else:
-            self.history = self.history.sel({dim: intersect})
-            self.label = self.label.sel({dim: intersect})
-            self.observation = self.observation.sel({dim: intersect})
-
-    @staticmethod
-    def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray:
-        """
-        Create an 1D xr.DataArray with given index name and value.
-
-        :param index_name: name of dimension
-        :param index_value: values of this dimension
-
-        :return: this array
-        """
-        ind = pd.DataFrame({'val': index_value}, index=index_value)
-        res = xr.Dataset.from_dataframe(ind).to_array().rename({'index': index_name}).squeeze(dim='variable', drop=True)
-        res.name = index_name
-        return res
-
-    def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray:
-        """
-        Set start and end date for slicing and execute self._slice().
-
-        :param data: data to slice
-        :param coord: name of axis to slice
-
-        :return: sliced data
-        """
-        start = self.kwargs.get('start', data.coords[coord][0].values)
-        end = self.kwargs.get('end', data.coords[coord][-1].values)
-        return self._slice(data, start, end, coord)
-
-    @staticmethod
-    def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray:
-        """
-        Slice through a given data_item (for example select only values of 2011).
-
-        :param data: data to slice
-        :param start: start date of slice
-        :param end: end date of slice
-        :param coord: name of axis to slice
-
-        :return: sliced data
-        """
-        return data.loc[{coord: slice(str(start), str(end))}]
-
-    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
-        """
-        Set all negative concentrations to zero.
-
-        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
-        #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
-        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".
-
-        :param data: data array containing variables to check
-        :param minimum: minimum value, by default this should be 0
-
-        :return: corrected data
-        """
-        chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
-                     "propane", "so2", "toluene"]
-        used_chem_vars = list(set(chem_vars) & set(self.variables))
-        data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
-        return data
-
-    def get_transposed_history(self) -> xr.DataArray:
-        """Return history.
-
-        :return: history with dimensions datetime, window, Stations, variables.
-        """
-        return self.history.transpose("datetime", "window", "Stations", "variables").copy()
-
-    def get_transposed_label(self) -> xr.DataArray:
-        """Return label.
-
-        :return: label with dimensions datetime, window, Stations, variables.
-        """
-        return self.label.squeeze("Stations").transpose("datetime", "window").copy()
-
-    def get_extremes_history(self) -> xr.DataArray:
-        """Return extremes history.
-
-        :return: extremes history with dimensions datetime, window, Stations, variables.
-        """
-        return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy()
-
-    def get_extremes_label(self) -> xr.DataArray:
-        """Return extremes label.
-
-        :return: extremes label with dimensions datetime, window, Stations, variables.
-        """
-        return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy()
-
-    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
-                          timedelta: Tuple[int, str] = (1, 'm')):
-        """
-        Multiply extremes.
-
-        This method extracts extreme values from self.labels which are defined in the argument extreme_values. One can
-        also decide only to extract extremes on the right tail of the distribution. When extreme_values is a list of
-        floats/ints all values larger (and smaller than negative extreme_values; extraction is performed in standardised
-        space) than are extracted iteratively. If for example extreme_values = [1.,2.] then a value of 1.5 would be
-        extracted once (for 0th entry in list), while a 2.5 would be extracted twice (once for each entry). Timedelta is
-        used to mark those extracted values by adding one min to each timestamp. As TOAR Data are hourly one can
-        identify those "artificial" data points later easily. Extreme inputs and labels are stored in
-        self.extremes_history and self.extreme_labels, respectively.
-
-        :param extreme_values: user definition of extreme
-        :param extremes_on_right_tail_only: if False also multiply values which are smaller then -extreme_values,
-            if True only extract values larger than extreme_values
-        :param timedelta: used as arguments for np.timedelta in order to mark extreme values on datetime
-        """
-        # check if labels or history is None
-        if (self.label is None) or (self.history is None):
-            logging.debug(f"{self.station} has `None' labels, skip multiply extremes")
-            return
-
-        # check type if inputs
-        extreme_values = helpers.to_list(extreme_values)
-        for i in extreme_values:
-            if not isinstance(i, number.__args__):
-                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
-                                f"{i} is type {type(i)}")
-
-        for extr_val in sorted(extreme_values):
-            # check if some extreme values are already extracted
-            if (self.extremes_label is None) or (self.extremes_history is None):
-                # extract extremes based on occurance in labels
-                if extremes_on_right_tail_only:
-                    extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1, )
-                else:
-                    extreme_label_idx = np.concatenate(((self.label < -extr_val).any(axis=0).values.reshape(-1, 1),
-                                                        (self.label > extr_val).any(axis=0).values.reshape(-1, 1)),
-                                                       axis=1).any(axis=1)
-                extremes_label = self.label[..., extreme_label_idx]
-                extremes_history = self.history[..., extreme_label_idx, :]
-                extremes_label.datetime.values += np.timedelta64(*timedelta)
-                extremes_history.datetime.values += np.timedelta64(*timedelta)
-                self.extremes_label = extremes_label  # .squeeze('Stations').transpose('datetime', 'window')
-                self.extremes_history = extremes_history  # .transpose('datetime', 'window', 'Stations', 'variables')
-            else:  # one extr value iteration is done already: self.extremes_label is NOT None...
-                if extremes_on_right_tail_only:
-                    extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, )
-                else:
-                    extreme_label_idx = np.concatenate(
-                        ((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1),
-                         (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1)
-                         ), axis=1).any(axis=1)
-                # check on existing extracted extremes to minimise computational costs for comparison
-                extremes_label = self.extremes_label[..., extreme_label_idx]
-                extremes_history = self.extremes_history[..., extreme_label_idx, :]
-                extremes_label.datetime.values += np.timedelta64(*timedelta)
-                extremes_history.datetime.values += np.timedelta64(*timedelta)
-                self.extremes_label = xr.concat([self.extremes_label, extremes_label], dim='datetime')
-                self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime')
-
 
 if __name__ == "__main__":
-    dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+    dp = DataPrepJoin('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
     print(dp)
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index dedcda0a6a3ff8fb9246bc6efe097eeb6b463999..b97d28c1cf71d35526207450d6b0bb386ddefdb7 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -13,7 +13,7 @@ import numpy as np
 import pandas as pd
 import xarray as xr
 
-from src.data_handling import BootStraps, Distributor, DataGenerator, DataPrep
+from src.data_handling import BootStraps, Distributor, DataGenerator, DataPrepJoin
 from src.helpers.datastore import NameNotFoundInDataStore
 from src.helpers import TimeTracking, statistics
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
@@ -358,7 +358,7 @@ class PostProcessing(RunEnvironment):
         return getter.get(self._sampling, None)
 
     @staticmethod
-    def _create_observation(data: DataPrep, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
+    def _create_observation(data: DataPrepJoin, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
                             normalised: bool) -> xr.DataArray:
         """
         Create observation as ground truth from given data.
@@ -402,7 +402,7 @@ class PostProcessing(RunEnvironment):
         ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
         return ols_prediction
 
-    def _create_persistence_forecast(self, data: DataPrep, persistence_prediction: xr.DataArray, mean: xr.DataArray,
+    def _create_persistence_forecast(self, data: DataPrepJoin, persistence_prediction: xr.DataArray, mean: xr.DataArray,
                                      std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
         """
         Create persistence forecast with given data.
diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py
index a8ca555c9748f7656fefc007922ee0d7df1992fa..85c2b6a7c256deb6bfcfbf73483652031d034a27 100644
--- a/test/test_data_handling/test_data_preparation.py
+++ b/test/test_data_handling/test_data_preparation.py
@@ -8,7 +8,8 @@ import pandas as pd
 import pytest
 import xarray as xr
 
-from src.data_handling.data_preparation import DataPrep
+# from src.data_handling.data_preparation import DataPrep
+from src.data_handling.data_preparation_join import DataPrepJoin as DataPrep
 from src.helpers.join import EmptyQueryResult
 
 
@@ -52,8 +53,9 @@ class TestDataPrep:
         meta_file = data_prep_no_init._set_meta_file_name()
         data_prep_no_init.kwargs = {"store_data_locally": False}
         data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
-        data_prep_no_init.download_data(file_name, meta_file)
-        assert isinstance(data_prep_no_init.data, xr.DataArray)
+        data, meta = data_prep_no_init.download_data(file_name, meta_file)
+        assert isinstance(data, xr.DataArray)
+        assert isinstance(meta, pd.DataFrame)
 
     def test_download_data_from_join(self, data_prep_no_init):
         file_name = data_prep_no_init._set_file_name()
@@ -70,7 +72,8 @@ class TestDataPrep:
         meta_file = data_prep_no_init._set_meta_file_name()
         data_prep_no_init.kwargs = {"store_data_locally": False}
         data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
-        data_prep_no_init.download_data(file_name, meta_file)
+        _, meta = data_prep_no_init.download_data(file_name, meta_file)
+        data_prep_no_init.meta = meta
        assert data_prep_no_init.check_station_meta() is None
         data_prep_no_init.station_type = "traffic"
         with pytest.raises(FileNotFoundError) as e:
@@ -83,8 +86,8 @@ class TestDataPrep:
         data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
         file_path = data_prep_no_init._set_file_name()
         meta_file_path = data_prep_no_init._set_meta_file_name()
-        os.remove(file_path)
-        os.remove(meta_file_path)
+        os.remove(file_path) if os.path.exists(file_path) else None
+        os.remove(meta_file_path) if os.path.exists(meta_file_path) else None
         assert not os.path.exists(file_path)
         assert not os.path.exists(meta_file_path)
         data_prep_no_init.kwargs = {"overwrite_local_data": True}
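
Usage sketch (not part of the patch): the renamed entry point can be exercised roughly like the __main__ block in data_preparation_join.py above; the station id, variable list, and statistics mapping below are just the sample values already used there, not new API.

    from src.data_handling.data_preparation_join import DataPrepJoin

    # construct a single-station preparation object; data are read from local files
    # if present, otherwise downloaded via the JOIN interface during initialisation
    dp = DataPrepJoin('data/', 'dummy', 'DEBW107', ['o3', 'temp'],
                      statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
    print(dp)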