diff --git a/src/data_handling/__init__.py b/src/data_handling/__init__.py index 1bd380d35cae73ba7dee2c2a10214483ab0ed62d..9ce7307d87fea03c11066068d8eccd78a02ed0bf 100644 --- a/src/data_handling/__init__.py +++ b/src/data_handling/__init__.py @@ -10,9 +10,6 @@ __date__ = '2020-04-17' from .bootstraps import BootStraps -from .data_preparation_join import DataPrepJoin -from .data_generator import DataGenerator -from .data_distributor import Distributor from .iterator import KerasIterator, DataCollection -from .advanced_data_handling import DefaultDataPreparation -from .data_preparation import StationPrep \ No newline at end of file +from .advanced_data_handling import DefaultDataPreparation, AbstractDataPreparation +from .data_preparation_neighbors import DataPreparationNeighbors diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py index 6fb0c723f7af70941959bf46723c802ebb921139..63d26bd9d9077eac814d19265e6d009fb2073774 100644 --- a/src/data_handling/advanced_data_handling.py +++ b/src/data_handling/advanced_data_handling.py @@ -16,7 +16,7 @@ import inspect from typing import Union, List, Tuple import logging from functools import reduce -from src.data_handling.data_preparation import StationPrep +from src.data_handling.station_preparation import StationPrep from src.helpers.join import EmptyQueryResult diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py deleted file mode 100644 index 2600afcbd8948c26a2b4cf37329b424cac69f40a..0000000000000000000000000000000000000000 --- a/src/data_handling/data_distributor.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Data Distribution Module. - -How to use ----------- - -Create distributor object from a generator object and parse it to the fit generator method. Provide the number of -steps per epoch with distributor's length method. - -.. code-block:: python - - model = YourKerasModel() - data_generator = DataGenerator(*args, **kwargs) - data_distributor = Distributor(data_generator, model, **kwargs) - history = model.fit_generator(generator=data_distributor.distribute_on_batches(), - steps_per_epoch=len(data_distributor), - epochs=10,) - -Additionally, a validation data set can be parsed using the length and distribute methods. -""" - -from __future__ import generator_stop - -__author__ = "Lukas Leufen, Felix Kleinert" -__date__ = '2019-12-05' - -import math - -import keras -import numpy as np - -from src.data_handling.data_generator import DataGenerator - - -class Distributor(keras.utils.Sequence): - """Distribute data generator elements according to mini batch size.""" - - def __init__(self, generator: DataGenerator, model: keras.models, batch_size: int = 256, - permute_data: bool = False, upsampling: bool = False): - """ - Set up distributor. - - :param generator: The generator object must be iterable and return inputs and targets on each iteration - :param model: a keras model with one or more output branches - :param batch_size: batch size to use - :param permute_data: data is randomly permuted if enabled on each train step - :param upsampling: upsample data with upsample extremes data from generator object and shuffle data or use only - the standard input data. 
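The module docstring of this (removed) `data_distributor.py` notes that a validation set can be passed through the same length and distribute methods; a minimal sketch of that wiring using only the constructor arguments documented here — `train_generator`, `val_generator`, and `my_model` are placeholder names, not objects from this code base:

.. code-block:: python

    # placeholders: any two DataGenerator instances and a compiled keras model
    train_distributor = Distributor(train_generator, my_model, batch_size=256,
                                    permute_data=True, upsampling=True)
    val_distributor = Distributor(val_generator, my_model, batch_size=256)
    history = my_model.fit_generator(
        generator=train_distributor.distribute_on_batches(),
        steps_per_epoch=len(train_distributor),
        validation_data=val_distributor.distribute_on_batches(),
        validation_steps=len(val_distributor),
        epochs=10)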
- """ - self.generator = generator - self.model = model - self.batch_size = batch_size - self.do_data_permutation = permute_data - self.upsampling = upsampling - - def _get_model_rank(self): - mod_out = self.model.output_shape - if isinstance(mod_out, tuple): - # only one output branch: (None, ahead) - mod_rank = 1 - elif isinstance(mod_out, list): - # multiple output branches, e.g.: [(None, ahead), (None, ahead)] - mod_rank = len(mod_out) - else: # pragma: no cover - raise TypeError("model output shape must either be tuple or list.") - return mod_rank - - def _get_number_of_mini_batches(self, values): - return math.ceil(values.shape[0] / self.batch_size) - - def _permute_data(self, x, y): - """ - Permute inputs x and labels y if permutation is enabled in instance. - - :param x: inputs - :param y: labels - :return: permuted or original data - """ - if self.do_data_permutation: - p = np.random.permutation(len(x)) # equiv to .shape[0] - x = x[p] - y = y[p] - return x, y - - def distribute_on_batches(self, fit_call=True): - """ - Create generator object to distribute mini batches. - - Split data from given generator object (usually for single station) according to the given batch size. Also - perform upsampling if enabled and random shuffling (either if data permutation is enabled or if upsampling is - enabled). Lastly multiply targets if provided model has multiple output branches. - - :param fit_call: switch to exit while loop after first iteration. This is used to determine the length of all - distributed mini batches. For default, fit_call is True to obtain infinite loop for training. - :return: yields next mini batch - """ - while True: - for k, v in enumerate(self.generator): - # get rank of output - mod_rank = self._get_model_rank() - # get data - x_total = np.copy(v[0]) - y_total = np.copy(v[1]) - if self.upsampling: - try: - s = self.generator.get_data_generator(k) - x_total = np.concatenate([x_total, np.copy(s.get_extremes_history())], axis=0) - y_total = np.concatenate([y_total, np.copy(s.get_extremes_label())], axis=0) - except AttributeError: # no extremes history / labels available, copy will fail - pass - # get number of mini batches - num_mini_batches = self._get_number_of_mini_batches(x_total) - # permute order for mini-batches - x_total, y_total = self._permute_data(x_total, y_total) - for prev, curr in enumerate(range(1, num_mini_batches + 1)): - x = x_total[prev * self.batch_size:curr * self.batch_size, ...] - y = [y_total[prev * self.batch_size:curr * self.batch_size, ...] for _ in range(mod_rank)] - if x is not None: # pragma: no branch - yield x, y - if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call: - return - - def __len__(self) -> int: - """ - Total number of distributed mini batches. 
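To make the slicing inside `distribute_on_batches` concrete, here is a self-contained toy version of the inner loop; the numbers are illustrative only, and `mod_rank = 2` mimics a model with two output branches that both receive the same targets:

.. code-block:: python

    import math
    import numpy as np

    batch_size, mod_rank = 4, 2
    x_total = np.arange(10).reshape(10, 1)
    y_total = np.arange(10).reshape(10, 1)
    num_mini_batches = math.ceil(x_total.shape[0] / batch_size)  # 3
    for prev, curr in enumerate(range(1, num_mini_batches + 1)):
        x = x_total[prev * batch_size:curr * batch_size]
        y = [y_total[prev * batch_size:curr * batch_size] for _ in range(mod_rank)]
        print(x.shape[0], len(y))  # prints: 4 2, 4 2, 2 2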
- - :return: the length of the distribute on batches object - """ - num_batch = 0 - for _ in self.distribute_on_batches(fit_call=False): - num_batch += 1 - return num_batch diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py deleted file mode 100644 index 8e14d019634d134d01b7edec92021aed20b59ecc..0000000000000000000000000000000000000000 --- a/src/data_handling/data_generator.py +++ /dev/null @@ -1,366 +0,0 @@ -"""Data Generator class to handle large arrays for machine learning.""" - -__author__ = 'Felix Kleinert, Lukas Leufen' -__date__ = '2019-11-07' - -import logging -import os -import pickle -from typing import Union, List, Tuple, Any, Dict - -import dask.array as da -import keras -import xarray as xr - -from src import helpers -from src.data_handling.data_preparation import AbstractDataPrep -from src.helpers.join import EmptyQueryResult - -number = Union[float, int] -num_or_list = Union[number, List[number]] -data_or_none = Union[xr.DataArray, None] - - -class DataGenerator(keras.utils.Sequence): - """ - This class is a generator to handle large arrays for machine learning. - - .. code-block:: python - - data_generator = DataGenerator(**args, **kwargs) - - Data generator item can be called manually by position (integer) or station id (string). Methods also accept lists - with exactly one entry of integer or string. - - .. code-block:: - - # select generator elements by position index - first_element = data_generator.get_data_generator([0]) # 1st element - n_element = data_generator.get_data_generator([4]) # 5th element - - # select by name - station_xy = data_generator.get_data_generator(["station_xy"]) # will raise KeyError if not available - - If used as iterator or directly called by get item method, the data generator class returns transposed labels and - history object from underlying data preparation class DataPrep. - - .. code-block:: python - - # select history and label by position - hist, labels = data_generator[0] - # by name - hist, labels = data_generator["station_xy"] - # as iterator - for (hist, labels) in data_generator: - pass - - This class can also be used with keras' fit_generator and predict_generator. Individual stations are the iterables. - """ - - def __init__(self, data_path: str, stations: Union[str, List[str]], variables: List[str], - interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None, - interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7, - window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, - data_preparation=None, **kwargs): - """ - Set up data generator. - - :param data_path: path to data - :param stations: list with all stations to include - :param variables: list with all used variables - :param interpolate_dim: dimension along which interpolation is applied - :param target_dim: dimension of target variable - :param target_var: name of target variable - :param station_type: TOAR station type classification (background, traffic) - :param interpolate_method: method of interpolation - :param limit_nan_fill: maximum gab in data to fill by interpolation - :param window_history_size: length of the history window - :param window_lead_time: lenght of the label window - :param transformation: transformation method to apply on data - :param extreme_values: set up the extreme value upsampling - :param kwargs: additional kwargs that are used in either DataPrep (transformation, start / stop period, ...) 
- or extreme values - """ - self.data_path = os.path.abspath(data_path) - self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp") - if not os.path.exists(self.data_path_tmp): - os.makedirs(self.data_path_tmp) - self.stations = helpers.to_list(stations) - self.variables = variables - self.interpolate_dim = interpolate_dim - self.target_dim = target_dim - self.target_var = target_var - self.station_type = station_type - self.interpolate_method = interpolate_method - self.limit_nan_fill = limit_nan_fill - self.window_history_size = window_history_size - self.window_lead_time = window_lead_time - self.extreme_values = extreme_values - self.DataPrep = data_preparation if data_preparation is not None else AbstractDataPrep - self.kwargs = kwargs - self.transformation = self.setup_transformation(transformation) - - def __repr__(self): - """Display all class attributes.""" - return f"DataGenerator(path='{self.data_path}', stations={self.stations}, " \ - f"variables={self.variables}, station_type={self.station_type}, " \ - f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \ - f"target_var='{self.target_var}', **{self.kwargs})" - - def __len__(self): - """Return the number of stations.""" - return len(self.stations) - - def __iter__(self) -> "DataGenerator": - """ - Define the __iter__ part of the iterator protocol to iterate through this generator. - - Sets the private attribute `_iterator` to 0. - """ - self._iterator = 0 - return self - - def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]: - """ - Get the data generator, and return the history and label data of this generator. - - This is the implementation of the __next__ method of the iterator protocol. - """ - if self._iterator < self.__len__(): - data = self.get_data_generator() - self._iterator += 1 - if data.history is not None and data.label is not None: # pragma: no branch - return data.get_transposed_history(), data.get_transposed_label() - else: - self.__next__() # pragma: no cover - else: - raise StopIteration - - def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]: - """ - Define the get item method for this generator. - - Retrieve data from generator and return history and labels. - - :param item: station key to choose the data generator. - :return: The generator's time series of history data and its labels - """ - data = self.get_data_generator(key=item) - return data.get_transposed_history(), data.get_transposed_label() - - def setup_transformation(self, transformation: Dict): - """ - Set up transformation by extracting all relevant information. - - Extract all information from transformation dictionary. Possible keys are scope. method, mean, and std. Scope - can either be station or data. Station scope means, that data transformation is performed for each station - independently (somehow like batch normalisation), whereas data scope means a transformation applied on the - entire data set. - - * If using data scope, mean and standard deviation (each only if required by transformation method) can either - be calculated accurate or as an estimate (faster implementation). This must be set in dictionary either - as "mean": "accurate" or "mean": "estimate". In both cases, the required statistics are calculated and saved. - After this calculations, the mean key is overwritten by the actual values to use. - * If using station scope, no additional information is required. 
- * If a transformation should be applied on base of existing values, these need to be provided in the respective - keys "mean" and "std" (again only if required for given method). - - :param transformation: the transformation dictionary as described above. - - :return: updated transformation dictionary - """ - if transformation is None: - return - transformation = transformation.copy() - scope = transformation.get("scope", "station") - method = transformation.get("method", "standardise") - mean = transformation.get("mean", None) - std = transformation.get("std", None) - if scope == "data": - if isinstance(mean, str): - if mean == "accurate": - mean, std = self.calculate_accurate_transformation(method) - elif mean == "estimate": - mean, std = self.calculate_estimated_transformation(method) - else: - raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or" - f"be an array with already calculated means. Given was: {mean}") - elif scope == "station": - mean, std = None, None - else: - raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}") - transformation["method"] = method - transformation["mean"] = mean - transformation["std"] = std - return transformation - - def calculate_accurate_transformation(self, method: str) -> Tuple[data_or_none, data_or_none]: - """ - Calculate accurate transformation statistics. - - Use all stations of this generator and calculate mean and standard deviation on entire data set using dask. - Because there can be much data, this can take a while. - - :param method: name of transformation method - - :return: accurate calculated mean and std (depending on transformation) - """ - tmp = [] - mean = None - std = None - for station in self.stations: - try: - data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type, - **self.kwargs) - chunks = (1, 100, data.data.shape[2]) - tmp.append(da.from_array(data.data.data, chunks=chunks)) - except EmptyQueryResult: - continue - tmp = da.concatenate(tmp, axis=1) - if method in ["standardise", "centre"]: - mean = da.nanmean(tmp, axis=1).compute() - mean = xr.DataArray(mean.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"]) - if method == "standardise": - std = da.nanstd(tmp, axis=1).compute() - std = xr.DataArray(std.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"]) - else: - raise NotImplementedError - return mean, std - - def calculate_estimated_transformation(self, method): - """ - Calculate estimated transformation statistics. - - Use all stations of this generator and calculate mean and standard deviation first for each station separately. - Afterwards, calculate the average mean and standard devation as estimated statistics. Because this method does - not consider the length of each data set, the estimated mean distinguishes from the real data mean. Furthermore, - the estimated standard deviation is assumed to be the mean (also not weighted) of all deviations. But this is - mathematically not true, but still a rough and faster estimation of the true standard deviation. Do not use this - method for further statistical calculation. However, in the scope of data preparation for machine learning, this - approach is decent ("it is just scaling"). 
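A small numeric sketch of the difference between the "accurate" and "estimate" statistics described above, with toy values for two stations of different length:

.. code-block:: python

    import numpy as np

    station_a = np.array([1., 1., 1., 1.])  # per-station mean 1.0
    station_b = np.array([5., 7.])          # per-station mean 6.0

    accurate_mean = np.concatenate([station_a, station_b]).mean()   # ~2.67, length-weighted
    estimated_mean = np.mean([station_a.mean(), station_b.mean()])  # 3.5, unweighted mean of means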
- - :param method: name of transformation method - - :return: accurate calculated mean and std (depending on transformation) - """ - data = [[]] * len(self.variables) - coords = {"variables": self.variables, "Stations": range(0)} - mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"]) - std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"]) - for station in self.stations: - try: - data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type, - **self.kwargs) - data.transform("datetime", method=method) - mean = mean.combine_first(data.mean) - std = std.combine_first(data.std) - data.transform("datetime", method=method, inverse=True) - except EmptyQueryResult: - continue - return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None - - def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True, - save_local_tmp_storage: bool = True) -> AbstractDataPrep: - """ - Create DataPrep object and preprocess data for given key. - - Select data for given key, create a DataPrep object and - * apply transformation (optional) - * interpolate - * make history, labels, and observation - * remove nans - * upsample extremes (optional). - Processed data can be stored locally in a .pickle file. If load local tmp storage is enabled, the get data - generator tries first to load data from local pickle file and only creates a new DataPrep object if it couldn't - load this data from disk. - - :param key: station key to choose the data generator. - :param load_local_tmp_storage: say if data should be processed from scratch or loaded as already processed data - from tmp pickle file to save computational time (but of course more disk space required). - :param save_local_tmp_storage: save processed data as temporal file locally (default True) - - :return: preprocessed data as a DataPrep instance - """ - station = self.get_station_key(key) - try: - if not load_local_tmp_storage: - raise FileNotFoundError - data = self._load_pickle_data(station, self.variables) - except FileNotFoundError: - logging.debug(f"load not pickle data for {station}") - data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type, - **self.kwargs) - if self.transformation is not None: - data.transform("datetime", **helpers.remove_items(self.transformation, "scope")) - data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill) - data.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim) - data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time) - data.make_observation(self.target_dim, self.target_var, self.interpolate_dim) - data.remove_nan(self.interpolate_dim) - if self.extreme_values is not None: - kwargs = {"extremes_on_right_tail_only": self.kwargs.get("extremes_on_right_tail_only", False)} - data.multiply_extremes(self.extreme_values, **kwargs) - if save_local_tmp_storage: - self._save_pickle_data(data) - return data - - def _save_pickle_data(self, data: Any): - """ - Save given data locally as .pickle in self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle'. 
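The temporary pickle name scheme referenced here, spelled out with illustrative values (the station id matches the examples used elsewhere in this patch, the date range is made up):

.. code-block:: python

    import os

    station, variables = "DEBW107", ["temp", "o3"]
    start, end = "2009-01-01", "2017-12-31"
    file = os.path.join("data/tmp",
                        f"{station}_{'_'.join(sorted(variables))}_{start}_{end}_.pickle")
    # -> data/tmp/DEBW107_o3_temp_2009-01-01_2017-12-31_.pickle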
- - :param data: any data, that should be saved - """ - date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}" - vars = '_'.join(sorted(data.variables)) - station = ''.join(data.station) - file = os.path.join(self.data_path_tmp, f"{station}_{vars}_{date}_.pickle") - with open(file, "wb") as f: - pickle.dump(data, f) - logging.debug(f"save pickle data to {file}") - - def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any: - """ - Load locally saved data from self.data_path_tmp and name '<station>_<var1>_<var2>_..._<varX>.pickle'. - - :param station: station to load - :param variables: list of variables to load - :return: loaded data - """ - date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}" - vars = '_'.join(sorted(variables)) - station = ''.join(station) - file = os.path.join(self.data_path_tmp, f"{station}_{vars}_{date}_.pickle") - with open(file, "rb") as f: - data = pickle.load(f) - logging.debug(f"load pickle data from {file}") - return data - - def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str: - """ - Return a valid station key or raise KeyError if this wasn't possible. - - :param key: station key to choose the data generator. - :return: station key (id from database) - """ - # extract value if given as list - if isinstance(key, list): - if len(key) == 1: - key = key[0] - else: - raise KeyError(f"More than one key was given: {key}") - # return station name either from key or the recent element from iterator - if key is None: - return self.stations[self._iterator] - else: - if isinstance(key, int): - if key < self.__len__(): - return self.stations[key] - else: - raise KeyError(f"{key} is not in range(0, {self.__len__()})") - elif isinstance(key, str): - if key in self.stations: - return key - else: - raise KeyError(f"{key} is not in stations") - else: - raise KeyError(f"Key has to be from Union[str, int]. Given was {key} ({type(key)})") diff --git a/src/data_handling/data_preparation_join.py b/src/data_handling/data_preparation_join.py deleted file mode 100644 index 86c7dee055c8258069307567b28ffcd113e13477..0000000000000000000000000000000000000000 --- a/src/data_handling/data_preparation_join.py +++ /dev/null @@ -1,124 +0,0 @@ -"""Data Preparation class to handle data processing for machine learning.""" - -__author__ = 'Felix Kleinert, Lukas Leufen' -__date__ = '2019-10-16' - -import datetime as dt -import inspect -import logging -from typing import Union, List - -import pandas as pd -import xarray as xr - -from src import helpers -from src.helpers import join -from src.data_handling.data_preparation import AbstractDataPrep - -# define a more general date type for type hinting -date = Union[dt.date, dt.datetime] -str_or_list = Union[str, List[str]] -number = Union[float, int] -num_or_list = Union[number, List[number]] -data_or_none = Union[xr.DataArray, None] - - -class DataPrepJoin(AbstractDataPrep): - """ - This class prepares data to be used in neural networks. - - The instance searches for local stored data, that meet the given demands. If no local data is found, the DataPrep - instance will load data from TOAR database and store this data locally to use the next time. For the moment, there - is only support for daily aggregated time series. The aggregation can be set manually and differ for each variable. - - After data loading, different data pre-processing steps can be executed to prepare the data for further - applications. 
Especially the following methods can be used for the pre-processing step: - - - interpolate: interpolate between data points by using xarray's interpolation method - - standardise: standardise data to mean=1 and std=1, centralise to mean=0, additional methods like normalise on \ - interval [0, 1] are not implemented yet. - - make window history: represent the history (time steps before) for training/ testing; X - - make labels: create target vector with given leading time steps for training/ testing; y - - remove Nans jointly from desired input and output, only keeps time steps where no NaNs are present in X AND y. \ - Use this method after the creation of the window history and labels to clean up the data cube. - - To create a DataPrep instance, it is needed to specify the stations by id (e.g. "DEBW107"), its network (e.g. UBA, - "Umweltbundesamt") and the variables to use. Further options can be set in the instance. - - * `statistics_per_var`: define a specific statistic to extract from the TOAR database for each variable. - * `start`: define a start date for the data cube creation. Default: Use the first entry in time series - * `end`: set the end date for the data cube. Default: Use last date in time series. - * `store_data_locally`: store recently downloaded data on local disk. Default: True - * set further parameters for xarray's interpolation methods to modify the interpolation scheme - - """ - - def __init__(self, path: str, station: Union[str, List[str]], variables: List[str], network: str = None, - station_type: str = None, **kwargs): - self.network = network - self.station_type = station_type - params = helpers.remove_items(inspect.getfullargspec(AbstractDataPrep.__init__).args, "self") - kwargs = {**{k: v for k, v in locals().items() if k in params and v is not None}, **kwargs} - super().__init__(**kwargs) - - def download_data(self, file_name, meta_file): - """ - Download data and meta from join. - - :param file_name: name of file to save data to (containing full path) - :param meta_file: name of the meta data file (also containing full path) - """ - data, meta = self.download_data_from_join(file_name, meta_file) - return data, meta - - def check_station_meta(self): - """ - Search for the entries in meta data and compare the value with the requested values. - - Will raise a FileNotFoundError if the values mismatch. - """ - if self.station_type is not None: - check_dict = {"station_type": self.station_type, "network_name": self.network} - for (k, v) in check_dict.items(): - if v is None: - continue - if self.meta.at[k, self.station[0]] != v: - logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " - f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new " - f"grapping from web.") - raise FileNotFoundError - - def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]: - """ - Download data from TOAR database using the JOIN interface. - - Data is transformed to a xarray dataset. If class attribute store_data_locally is true, data is additionally - stored locally using given names for file and meta file. 
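The DataFrame-to-DataArray conversion performed by `download_data_from_join`, reduced to a standalone toy example; the values are arbitrary and only the dimension layout matters:

.. code-block:: python

    import pandas as pd
    import xarray as xr

    df = pd.DataFrame({"o3": [1., 2.], "temp": [10., 11.]},
                      index=pd.date_range("2020-01-01", periods=2, name="datetime"))
    xarr = {"DEBW107": xr.DataArray(df, dims=["datetime", "variables"])}
    xarr = xr.Dataset(xarr).to_array(dim="Stations")
    print(xarr.dims)  # ('Stations', 'datetime', 'variables')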
- - :param file_name: name of file to save data to (containing full path) - :param meta_file: name of the meta data file (also containing full path) - - :return: downloaded data and its meta data - """ - df_all = {} - df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var, - station_type=self.station_type, network_name=self.network, sampling=self.sampling) - df_all[self.station[0]] = df - # convert df_all to xarray - xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()} - xarr = xr.Dataset(xarr).to_array(dim='Stations') - if self.kwargs.get('store_data_locally', True): - # save locally as nc/csv file - xarr.to_netcdf(path=file_name) - meta.to_csv(meta_file) - return xarr, meta - - def __repr__(self): - """Represent class attributes.""" - return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \ - f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})" - - -if __name__ == "__main__": - dp = DataPrepJoin('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) - print(dp) diff --git a/src/data_handling/data_preparation_neighbors.py b/src/data_handling/data_preparation_neighbors.py index 855ba3c04f455171a81f7dc4595f8b8c64409a87..fa5744e732adbfe6488705f73e24e8376116c5bc 100644 --- a/src/data_handling/data_preparation_neighbors.py +++ b/src/data_handling/data_preparation_neighbors.py @@ -4,7 +4,7 @@ __date__ = '2020-07-17' from src.helpers import to_list -from src.data_handling.data_preparation import StationPrep +from src.data_handling.station_preparation import StationPrep from src.data_handling.advanced_data_handling import DefaultDataPreparation import os diff --git a/src/data_handling/data_preparation.py b/src/data_handling/station_preparation.py similarity index 52% rename from src/data_handling/data_preparation.py rename to src/data_handling/station_preparation.py index bff3b9f12f11d481ea70a470a14795d7bce807b5..da8c3ad83bc3c794e540863f6343b7337484ee7d 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/station_preparation.py @@ -8,12 +8,10 @@ import logging import os from functools import reduce from typing import Union, List, Iterable, Tuple, Dict -from src.helpers.join import EmptyQueryResult import numpy as np import pandas as pd import xarray as xr -import dask.array as da from src.configuration import check_path_and_create from src import helpers @@ -635,534 +633,6 @@ class StationPrep(AbstractStationPrep): return mean, std, self._transform_method -class AbstractDataPrep(object): - """ - This class prepares data to be used in neural networks. - - The instance searches for local stored data, that meet the given demands. If no local data is found, the DataPrep - instance will load data from TOAR database and store this data locally to use the next time. For the moment, there - is only support for daily aggregated time series. The aggregation can be set manually and differ for each variable. - - After data loading, different data pre-processing steps can be executed to prepare the data for further - applications. Especially the following methods can be used for the pre-processing step: - - - interpolate: interpolate between data points by using xarray's interpolation method - - standardise: standardise data to mean=1 and std=1, centralise to mean=0, additional methods like normalise on \ - interval [0, 1] are not implemented yet. 
- - make window history: represent the history (time steps before) for training/ testing; X - - make labels: create target vector with given leading time steps for training/ testing; y - - remove Nans jointly from desired input and output, only keeps time steps where no NaNs are present in X AND y. \ - Use this method after the creation of the window history and labels to clean up the data cube. - - To create a DataPrep instance, it is needed to specify the stations by id (e.g. "DEBW107"), its network (e.g. UBA, - "Umweltbundesamt") and the variables to use. Further options can be set in the instance. - - * `statistics_per_var`: define a specific statistic to extract from the TOAR database for each variable. - * `start`: define a start date for the data cube creation. Default: Use the first entry in time series - * `end`: set the end date for the data cube. Default: Use last date in time series. - * `store_data_locally`: store recently downloaded data on local disk. Default: True - * set further parameters for xarray's interpolation methods to modify the interpolation scheme - - """ - - def __init__(self, path: str, station: Union[str, List[str]], variables: List[str], **kwargs): - """Construct instance.""" - self.path = os.path.abspath(path) - self.station = helpers.to_list(station) - self.variables = variables - self.mean: data_or_none = None - self.std: data_or_none = None - self.history: data_or_none = None - self.label: data_or_none = None - self.observation: data_or_none = None - self.extremes_history: data_or_none = None - self.extremes_label: data_or_none = None - self.kwargs = kwargs - self.data = None - self.meta = None - self._transform_method = None - self.statistics_per_var = kwargs.get("statistics_per_var", None) - self.sampling = kwargs.get("sampling", "daily") - if self.statistics_per_var is not None or self.sampling == "hourly": - self.load_data() - else: - raise NotImplementedError("Either select hourly data or provide statistics_per_var.") - - def load_data(self, source_name=""): - """ - Load data and meta data either from local disk (preferred) or download new data by using a custom download method. - - Data is either downloaded, if no local data is available or parameter overwrite_local_data is true. In both - cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not - set, it is assumed, that data should be saved locally. - """ - source_name = source_name if len(source_name) == 0 else f" from {source_name}" - check_path_and_create(self.path) - file_name = self._set_file_name() - meta_file = self._set_meta_file_name() - if self.kwargs.get('overwrite_local_data', False): - logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}") - if os.path.exists(file_name): - os.remove(file_name) - if os.path.exists(meta_file): - os.remove(meta_file) - data, self.meta = self.download_data(file_name, meta_file) - logging.debug(f"loaded new data{source_name}") - else: - try: - logging.debug(f"try to load local data from: {file_name}") - data = xr.open_dataarray(file_name) - self.meta = pd.read_csv(meta_file, index_col=0) - self.check_station_meta() - logging.debug("loading finished") - except FileNotFoundError as e: - logging.debug(e) - logging.debug(f"load new data{source_name}") - data, self.meta = self.download_data(file_name, meta_file) - logging.debug("loading finished") - # create slices and check for negative concentration. 
- data = self._slice_prep(data) - self.data = self.check_for_negative_concentrations(data) - - def download_data(self, file_name, meta_file) -> [xr.DataArray, pd.DataFrame]: - """ - Download data and meta. - - :param file_name: name of file to save data to (containing full path) - :param meta_file: name of the meta data file (also containing full path) - """ - raise NotImplementedError - - def check_station_meta(self): - """ - Placeholder function to implement some additional station meta data check if desired. - - Ideally, this method should raise a FileNotFoundError if a value mismatch to load fresh data from a source. If - this method is not required for your application just inherit and add the `pass` command inside the method. The - NotImplementedError is more a reminder that you could use it. - """ - raise NotImplementedError - - def _set_file_name(self): - all_vars = sorted(self.statistics_per_var.keys()) - return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc") - - def _set_meta_file_name(self): - all_vars = sorted(self.statistics_per_var.keys()) - return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv") - - def __repr__(self): - """Represent class attributes.""" - return f"AbstractDataPrep(path='{self.path}', station={self.station}, variables={self.variables}, " \ - f"**{self.kwargs})" - - def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, - **kwargs): - """ - Interpolate values according to different methods. - - (Copy paste from dataarray.interpolate_na) - - :param dim: - Specifies the dimension along which to interpolate. - :param method: - {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', - 'polynomial', 'barycentric', 'krog', 'pchip', - 'spline', 'akima'}, optional - String indicating which method to use for interpolation: - - - 'linear': linear interpolation (Default). Additional keyword - arguments are passed to ``numpy.interp`` - - 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', - 'polynomial': are passed to ``scipy.interpolate.interp1d``. If - method=='polynomial', the ``order`` keyword argument must also be - provided. - - 'barycentric', 'krog', 'pchip', 'spline', and `akima`: use their - respective``scipy.interpolate`` classes. - :param limit: - default None - Maximum number of consecutive NaNs to fill. Must be greater than 0 - or None for no limit. - :param use_coordinate: - default True - Specifies which index to use as the x values in the interpolation - formulated as `y = f(x)`. If False, values are treated as if - eqaully-spaced along `dim`. If True, the IndexVariable `dim` is - used. If use_coordinate is a string, it specifies the name of a - coordinate variariable to use as the index. - :param kwargs: - - :return: xarray.DataArray - """ - self.data = self.data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, - **kwargs) - - @staticmethod - def check_inverse_transform_params(mean: data_or_none, std: data_or_none, method: str) -> None: - """ - Support inverse_transformation method. - - Validate if all required statistics are available for given method. E.g. centering requires mean only, whereas - normalisation requires mean and standard deviation. Will raise an AttributeError on missing requirements. 
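With the defaults used elsewhere in this package (linear interpolation, at most one missing step filled), the interpolation call above boils down to the following; `data_prep` is a hypothetical instance:

.. code-block:: python

    data_prep.interpolate(dim="datetime", method="linear", limit=1)
    # which in turn wraps the xarray call
    # data_prep.data = data_prep.data.interpolate_na(dim="datetime", method="linear", limit=1)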
- - :param mean: data with all mean values - :param std: data with all standard deviation values - :param method: name of transformation method - """ - msg = "" - if method in ['standardise', 'centre'] and mean is None: - msg += "mean, " - if method == 'standardise' and std is None: - msg += "std, " - if len(msg) > 0: - raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}") - - def inverse_transform(self) -> None: - """ - Perform inverse transformation. - - Will raise an AssertionError, if no transformation was performed before. Checks first, if all required - statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by - new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate, that the - current data is not transformed. - """ - - def f_inverse(data, mean, std, method_inverse): - if method_inverse == 'standardise': - return statistics.standardise_inverse(data, mean, std), None, None - elif method_inverse == 'centre': - return statistics.centre_inverse(data, mean), None, None - elif method_inverse == 'normalise': - raise NotImplementedError - else: - raise NotImplementedError - - if self._transform_method is None: - raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.") - self.check_inverse_transform_params(self.mean, self.std, self._transform_method) - self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method) - self._transform_method = None - - def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None, - std=None) -> None: - """ - Transform data according to given transformation settings. - - This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0 - and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale - (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This - method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the - internal transform method, internal mean and internal standard deviation weren't set ('inverse=True'). - - :param string/int dim: This param is not used for inverse transformation. - | for xarray.DataArray as string: name of dimension which should be standardised - | for pandas.DataFrame as int: axis of dimension which should be standardised - :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented - yet. This param is not used for inverse transformation. - :param inverse: Switch between transformation and inverse transformation. - - :return: xarray.DataArrays or pandas.DataFrames: - #. mean: Mean of data - #. std: Standard deviation of data - #. 
data: Standardised data - """ - - def f(data): - if method == 'standardise': - return statistics.standardise(data, dim) - elif method == 'centre': - return statistics.centre(data, dim) - elif method == 'normalise': - # use min/max of data or given min/max - raise NotImplementedError - else: - raise NotImplementedError - - def f_apply(data): - if method == "standardise": - return mean, std, statistics.standardise_apply(data, mean, std) - elif method == "centre": - return mean, None, statistics.centre_apply(data, mean) - else: - raise NotImplementedError - - if not inverse: - if self._transform_method is not None: - raise AssertionError(f"Transform method is already set. Therefore, data was already transformed with " - f"{self._transform_method}. Please perform inverse transformation of data first.") - self.mean, self.std, self.data = locals()["f" if mean is None else "f_apply"](self.data) - self._transform_method = method - else: - self.inverse_transform() - - def get_transformation_information(self, variable: str) -> Tuple[data_or_none, data_or_none, str]: - """ - Extract transformation statistics and method. - - Get mean and standard deviation for given variable and the transformation method if set. If a transformation - depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are - returned with None as fill value. - - :param variable: Variable for which the information on transformation is requested. - - :return: mean, standard deviation and transformation method - """ - try: - mean = self.mean.sel({'variables': variable}).values - except AttributeError: - mean = None - try: - std = self.std.sel({'variables': variable}).values - except AttributeError: - std = None - return mean, std, self._transform_method - - def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: - """ - Create a xr.DataArray containing history data. - - Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted - data. This is used to represent history in the data. Results are stored in history attribute. - - :param dim_name_of_inputs: Name of dimension which contains the input variables - :param window: number of time steps to look back in history - Note: window will be treated as negative value. This should be in agreement with looking back on - a time line. Nonetheless positive values are allowed but they are converted to its negative - expression - :param dim_name_of_shift: Dimension along shift will be applied - """ - window = -abs(window) - self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables}) - - def shift(self, dim: str, window: int) -> xr.DataArray: - """ - Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0). - - :param dim: dimension along shift is applied - :param window: number of steps to shift (corresponds to the window length) - - :return: shifted data - """ - start = 1 - end = 1 - if window <= 0: - start = window - else: - end = window + 1 - res = [] - for w in range(start, end): - res.append(self.data.shift({dim: -w})) - window_array = self.create_index_array('window', range(start, end)) - res = xr.concat(res, dim=window_array) - return res - - def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, - window: int) -> None: - """ - Create a xr.DataArray containing labels. 
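A condensed, standalone restatement of the `shift` / `make_history_window` logic above, on arbitrary toy data: a non-positive window collects lagged copies along a new `window` dimension, a positive window collects lead times.

.. code-block:: python

    import numpy as np
    import pandas as pd
    import xarray as xr

    data = xr.DataArray(np.arange(5.), dims="datetime",
                        coords={"datetime": pd.date_range("2020-01-01", periods=5)})
    # history for window=-2: shifted copies at window indices -2, -1, 0
    history = xr.concat([data.shift(datetime=-w) for w in range(-2, 1)],
                        dim=pd.Index(range(-2, 1), name="window"))
    # labels for window=2: lead times at window indices 1, 2
    labels = xr.concat([data.shift(datetime=-w) for w in range(1, 3)],
                       dim=pd.Index(range(1, 3), name="window"))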
- - Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label - attribute. - - :param dim_name_of_target: Name of dimension which contains the target variable - :param target_var: Name of target variable in 'dimension' - :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied - :param window: lead time of label - """ - window = abs(window) - self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var}) - - def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None: - """ - Create a xr.DataArray containing observations. - - Observations are defined as value of the current time step t. Set observation attribute. - - :param dim_name_of_target: Name of dimension which contains the observation variable - :param target_var: Name of observation variable(s) in 'dimension' - :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied - """ - self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var}) - - def remove_nan(self, dim: str) -> None: - """ - Remove all NAs slices along dim which contain nans in history, label and observation. - - This is done to present only a full matrix to keras.fit. Update history, label, and observation attribute. - - :param dim: dimension along the remove is performed. - """ - intersect = [] - if (self.history is not None) and (self.label is not None): - non_nan_history = self.history.dropna(dim=dim) - non_nan_label = self.label.dropna(dim=dim) - non_nan_observation = self.observation.dropna(dim=dim) - intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, - non_nan_observation.coords[dim].values)) - - min_length = self.kwargs.get("min_length", 0) - if len(intersect) < max(min_length, 1): - self.history = None - self.label = None - self.observation = None - else: - self.history = self.history.sel({dim: intersect}) - self.label = self.label.sel({dim: intersect}) - self.observation = self.observation.sel({dim: intersect}) - - @staticmethod - def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray: - """ - Create an 1D xr.DataArray with given index name and value. - - :param index_name: name of dimension - :param index_value: values of this dimension - - :return: this array - """ - ind = pd.DataFrame({'val': index_value}, index=index_value) - res = xr.Dataset.from_dataframe(ind).to_array().rename({'index': index_name}).squeeze(dim='variable', drop=True) - res.name = index_name - return res - - def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray: - """ - Set start and end date for slicing and execute self._slice(). - - :param data: data to slice - :param coord: name of axis to slice - - :return: sliced data - """ - start = self.kwargs.get('start', data.coords[coord][0].values) - end = self.kwargs.get('end', data.coords[coord][-1].values) - return self._slice(data, start, end, coord) - - @staticmethod - def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray: - """ - Slice through a given data_item (for example select only values of 2011). 
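In plain xarray terms, the slicing implemented by `_slice` amounts to the following, e.g. keeping only the year 2011 of a DataArray with a "datetime" coordinate (`data` is a placeholder):

.. code-block:: python

    data_2011 = data.loc[{"datetime": slice("2011-01-01", "2011-12-31")}]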
- - :param data: data to slice - :param start: start date of slice - :param end: end date of slice - :param coord: name of axis to slice - - :return: sliced data - """ - return data.loc[{coord: slice(str(start), str(end))}] - - def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray: - """ - Set all negative concentrations to zero. - - Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/ - #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox", - "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene". - - :param data: data array containing variables to check - :param minimum: minimum value, by default this should be 0 - - :return: corrected data - """ - chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", - "propane", "so2", "toluene"] - used_chem_vars = list(set(chem_vars) & set(self.variables)) - data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) - return data - - def get_transposed_history(self) -> xr.DataArray: - """Return history. - - :return: history with dimensions datetime, window, Stations, variables. - """ - return self.history.transpose("datetime", "window", "Stations", "variables").copy() - - def get_transposed_label(self) -> xr.DataArray: - """Return label. - - :return: label with dimensions datetime, window, Stations, variables. - """ - return self.label.squeeze("Stations").transpose("datetime", "window").copy() - - def get_extremes_history(self) -> xr.DataArray: - """Return extremes history. - - :return: extremes history with dimensions datetime, window, Stations, variables. - """ - return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy() - - def get_extremes_label(self) -> xr.DataArray: - """Return extremes label. - - :return: extremes label with dimensions datetime, window, Stations, variables. - """ - return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy() - - def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, - timedelta: Tuple[int, str] = (1, 'm')): - """ - Multiply extremes. - - This method extracts extreme values from self.labels which are defined in the argument extreme_values. One can - also decide only to extract extremes on the right tail of the distribution. When extreme_values is a list of - floats/ints all values larger (and smaller than negative extreme_values; extraction is performed in standardised - space) than are extracted iteratively. If for example extreme_values = [1.,2.] then a value of 1.5 would be - extracted once (for 0th entry in list), while a 2.5 would be extracted twice (once for each entry). Timedelta is - used to mark those extracted values by adding one min to each timestamp. As TOAR Data are hourly one can - identify those "artificial" data points later easily. Extreme inputs and labels are stored in - self.extremes_history and self.extreme_labels, respectively. 
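A counting sketch for the duplication rule described above: with `extreme_values=[1., 2.]` and both tails considered, a standardised label of 1.5 is extracted once and a 2.5 twice (toy labels, not real data):

.. code-block:: python

    import numpy as np

    labels = np.array([0.3, 1.5, 2.5, -2.5])
    for extr_val in sorted([1., 2.]):
        mask = np.abs(labels) > extr_val   # both tails, i.e. extremes_on_right_tail_only=False
        print(extr_val, int(mask.sum()))   # 1.0 -> 3 extracted, 2.0 -> 2 extracted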
- - :param extreme_values: user definition of extreme - :param extremes_on_right_tail_only: if False also multiply values which are smaller then -extreme_values, - if True only extract values larger than extreme_values - :param timedelta: used as arguments for np.timedelta in order to mark extreme values on datetime - """ - # check if labels or history is None - if (self.label is None) or (self.history is None): - logging.debug(f"{self.station} has `None' labels, skip multiply extremes") - return - - # check type if inputs - extreme_values = helpers.to_list(extreme_values) - for i in extreme_values: - if not isinstance(i, number.__args__): - raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element " - f"{i} is type {type(i)}") - - for extr_val in sorted(extreme_values): - # check if some extreme values are already extracted - if (self.extremes_label is None) or (self.extremes_history is None): - # extract extremes based on occurance in labels - if extremes_on_right_tail_only: - extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1, ) - else: - extreme_label_idx = np.concatenate(((self.label < -extr_val).any(axis=0).values.reshape(-1, 1), - (self.label > extr_val).any(axis=0).values.reshape(-1, 1)), - axis=1).any(axis=1) - extremes_label = self.label[..., extreme_label_idx] - extremes_history = self.history[..., extreme_label_idx, :] - extremes_label.datetime.values += np.timedelta64(*timedelta) - extremes_history.datetime.values += np.timedelta64(*timedelta) - self.extremes_label = extremes_label # .squeeze('Stations').transpose('datetime', 'window') - self.extremes_history = extremes_history # .transpose('datetime', 'window', 'Stations', 'variables') - else: # one extr value iteration is done already: self.extremes_label is NOT None... 
- if extremes_on_right_tail_only: - extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, ) - else: - extreme_label_idx = np.concatenate( - ((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1), - (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1) - ), axis=1).any(axis=1) - # check on existing extracted extremes to minimise computational costs for comparison - extremes_label = self.extremes_label[..., extreme_label_idx] - extremes_history = self.extremes_history[..., extreme_label_idx, :] - extremes_label.datetime.values += np.timedelta64(*timedelta) - extremes_history.datetime.values += np.timedelta64(*timedelta) - self.extremes_label = xr.concat([self.extremes_label, extremes_label], dim='datetime') - self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime') - - if __name__ == "__main__": # dp = AbstractDataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) # print(dp) diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py index 839b02203b22c2f5538613601aa125ed30455b0b..0c4e00051489aec5c6762d3221fca43ee5a39cf3 100644 --- a/test/test_data_handling/test_bootstraps.py +++ b/test/test_data_handling/test_bootstraps.py @@ -7,8 +7,7 @@ import numpy as np import pytest import xarray as xr -from src.data_handling.bootstraps import BootStraps, CreateShuffledData, BootStrapGenerator -from src.data_handling.data_generator import DataGenerator +from src.data_handling.bootstraps import BootStraps from src.data_handling import DataPrepJoin diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py deleted file mode 100644 index 43c61be2134d68e1f81ed50420e2a801c9e63646..0000000000000000000000000000000000000000 --- a/test/test_data_handling/test_data_distributor.py +++ /dev/null @@ -1,121 +0,0 @@ -import math -import os - -import keras -import numpy as np -import pytest - -from src.data_handling.data_distributor import Distributor -from src.data_handling.data_generator import DataGenerator -from src.data_handling import DataPrepJoin -from test.test_modules.test_training import my_test_model - - -class TestDistributor: - - @pytest.fixture - def generator(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], - 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, - data_preparation=DataPrepJoin) - - @pytest.fixture - def generator_two_stations(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), ['DEBW107', 'DEBW013'], - ['o3', 'temp'], 'datetime', 'variables', 'o3', - statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, - data_preparation=DataPrepJoin) - - @pytest.fixture - def model(self): - return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False) - - @pytest.fixture - def model_with_minor_branch(self): - return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True) - - @pytest.fixture - def distributor(self, generator, model): - return Distributor(generator, model) - - def test_init_defaults(self, distributor): - assert distributor.batch_size == 256 - assert distributor.do_data_permutation is False - - def test_get_model_rank(self, distributor, model_with_minor_branch): - assert distributor._get_model_rank() == 1 - distributor.model = model_with_minor_branch - assert distributor._get_model_rank() == 2 - distributor.model = 1 - - def 
test_get_number_of_mini_batches(self, distributor): - values = np.zeros((2311, 19)) - assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size) - - def test_distribute_on_batches_single_loop(self, generator_two_stations, model): - d = Distributor(generator_two_stations, model) - for e in d.distribute_on_batches(fit_call=False): - assert e[0].shape[0] <= d.batch_size - - def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model): - d = Distributor(generator_two_stations, model) - elements = [] - for i, e in enumerate(d.distribute_on_batches()): - if i < len(d): - elements.append(e[0]) - elif i == 2 * len(d): # check if all elements are repeated - assert np.testing.assert_array_equal(e[0], elements[i - len(d)]) is None - else: # break when 3rd iteration starts (is called as infinite loop) - break - - def test_len(self, distributor): - assert len(distributor) == math.ceil(len(distributor.generator[0][0]) / 256) - - def test_len_two_stations(self, generator_two_stations, model): - gen = generator_two_stations - d = Distributor(gen, model) - expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256) - assert len(d) == expected - - def test_permute_data_no_permutation(self, distributor): - x = np.array(range(20)).reshape(2, 10).T - y = np.array(range(10)).reshape(10, 1) - x_perm, y_perm = distributor._permute_data(x, y) - assert np.testing.assert_equal(x, x_perm) is None - assert np.testing.assert_equal(y, y_perm) is None - - def test_permute_data(self, distributor): - x = np.array(range(20)).reshape(2, 10).T - y = np.array(range(10)).reshape(10, 1) - distributor.do_data_permutation = True - x_perm, y_perm = distributor._permute_data(x, y) - assert x_perm[0, 0] == y_perm[0] - assert x_perm[0, 1] == y_perm[0] + 10 - assert x_perm[5, 0] == y_perm[5] - assert x_perm[5, 1] == y_perm[5] + 10 - assert x_perm[-1, 0] == y_perm[-1] - assert x_perm[-1, 1] == y_perm[-1] + 10 - # resort x_perm and compare if equal to x - x_perm.sort(axis=0) - y_perm.sort(axis=0) - assert np.testing.assert_equal(x, x_perm) is None - assert np.testing.assert_equal(y, y_perm) is None - - def test_distribute_on_batches_upsampling_no_extremes_given(self, generator, model): - d = Distributor(generator, model, upsampling=True) - gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0] - num_mini_batches = math.ceil(gen_len / d.batch_size) - i = 0 - for i, e in enumerate(d.distribute_on_batches(fit_call=False)): - assert e[0].shape[0] <= d.batch_size - assert i + 1 == num_mini_batches - - def test_distribute_on_batches_upsampling(self, generator, model): - generator.extreme_values = [1] - d = Distributor(generator, model, upsampling=True) - gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0] - extr_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_extremes_label().shape[0] - i = 0 - for i, e in enumerate(d.distribute_on_batches(fit_call=False)): - assert e[0].shape[0] <= d.batch_size - assert i + 1 == math.ceil((gen_len + extr_len) / d.batch_size) diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 3144bde3440d861e109c4a3b0da8b77d317faa2b..4a113e842e7d75795cbed73ea37cd94079c983ba 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -6,7 +6,6 @@ import numpy as np import pytest import xarray as xr -from 
src.data_handling.data_generator import DataGenerator from src.data_handling import DataPrepJoin from src.helpers.join import EmptyQueryResult diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index d58c1a973ec474b2ec786271dff9d35ce5ca94d9..592a15ca6727a64b648095b610535e08c7aae751 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -10,8 +10,6 @@ import pytest from keras.callbacks import History from src.data_handling import DataPrepJoin -from src.data_handling.data_distributor import Distributor -from src.data_handling.data_generator import DataGenerator from src.helpers import PyTestRegex from src.model_modules.flatten import flatten_tail from src.model_modules.inception_model import InceptionModelBase
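For orientation, the names exported by the updated `src/data_handling/__init__.py` at the top of this patch, i.e. the public entry points remaining after `DataGenerator`, `Distributor`, and `DataPrepJoin` are removed:

.. code-block:: python

    from src.data_handling import BootStraps
    from src.data_handling import KerasIterator, DataCollection
    from src.data_handling import DefaultDataPreparation, AbstractDataPreparation
    from src.data_handling import DataPreparationNeighbors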