diff --git a/mlair/data_handling/__init__.py b/mlair/data_handling/__init__.py index 5139d13da91025dad15db04fdfea34309c4e28ff..cb5aa5db0f29cf51d32ed54e810fa9b363d80cc6 100644 --- a/mlair/data_handling/__init__.py +++ b/mlair/data_handling/__init__.py @@ -10,6 +10,6 @@ __date__ = '2020-04-17' from .bootstraps import BootStraps -from .data_preparation import DataPrep +from .data_preparation_join import DataPrepJoin from .data_generator import DataGenerator from .data_distributor import Distributor diff --git a/mlair/data_handling/data_generator.py b/mlair/data_handling/data_generator.py index 4c61454b0c751eeb6338f6f4e9072cabe14379a1..c0a32771c292ea7b004839877904ba897519c2f8 100644 --- a/mlair/data_handling/data_generator.py +++ b/mlair/data_handling/data_generator.py @@ -13,7 +13,7 @@ import keras import xarray as xr from mlair import helpers -from mlair.data_handling.data_preparation import DataPrep +from mlair.data_handling.data_preparation import AbstractDataPrep from mlair.helpers.join import EmptyQueryResult number = Union[float, int] @@ -57,15 +57,15 @@ class DataGenerator(keras.utils.Sequence): This class can also be used with keras' fit_generator and predict_generator. Individual stations are the iterables. """ - def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str], + def __init__(self, data_path: str, stations: Union[str, List[str]], variables: List[str], interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None, interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7, - window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs): + window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, + data_preparation=None, **kwargs): """ Set up data generator. 
:param data_path: path to data - :param network: the observational network, the data should come from :param stations: list with all stations to include :param variables: list with all used variables :param interpolate_dim: dimension along which interpolation is applied @@ -85,7 +85,6 @@ class DataGenerator(keras.utils.Sequence): self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp") if not os.path.exists(self.data_path_tmp): os.makedirs(self.data_path_tmp) - self.network = network self.stations = helpers.to_list(stations) self.variables = variables self.interpolate_dim = interpolate_dim @@ -97,12 +96,13 @@ class DataGenerator(keras.utils.Sequence): self.window_history_size = window_history_size self.window_lead_time = window_lead_time self.extreme_values = extreme_values + self.DataPrep = data_preparation if data_preparation is not None else AbstractDataPrep self.kwargs = kwargs self.transformation = self.setup_transformation(transformation) def __repr__(self): """Display all class attributes.""" - return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \ + return f"DataGenerator(path='{self.data_path}', stations={self.stations}, " \ f"variables={self.variables}, station_type={self.station_type}, " \ f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \ f"target_var='{self.target_var}', **{self.kwargs})" @@ -210,8 +210,8 @@ class DataGenerator(keras.utils.Sequence): std = None for station in self.stations: try: - data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type, - **self.kwargs) + data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type, + **self.kwargs) chunks = (1, 100, data.data.shape[2]) tmp.append(da.from_array(data.data.data, chunks=chunks)) except EmptyQueryResult: @@ -249,8 +249,8 @@ class DataGenerator(keras.utils.Sequence): std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"]) for station in self.stations: try: - data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type, - **self.kwargs) + data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type, + **self.kwargs) data.transform("datetime", method=method) mean = mean.combine_first(data.mean) std = std.combine_first(data.std) @@ -260,7 +260,7 @@ class DataGenerator(keras.utils.Sequence): return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True, - save_local_tmp_storage: bool = True) -> DataPrep: + save_local_tmp_storage: bool = True) -> AbstractDataPrep: """ Create DataPrep object and preprocess data for given key. 
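Note on the hunks above: DataGenerator now instantiates whatever class arrives via the new data_preparation argument (falling back to AbstractDataPrep), so any subclass that implements download_data and check_station_meta can be plugged in. A minimal sketch of such a plug-in, assuming a hypothetical CsvDataPrep whose file layout and meta content are illustrative only:

import pandas as pd
import xarray as xr

from mlair.data_handling.data_generator import DataGenerator
from mlair.data_handling.data_preparation import AbstractDataPrep


class CsvDataPrep(AbstractDataPrep):
    """Hypothetical data preparation that reads one local csv file per station."""

    def download_data(self, file_name, meta_file):
        # one csv per station: datetime index, one column per requested variable
        df = pd.read_csv(f"{self.path}/{self.station[0]}.csv", index_col=0, parse_dates=True)
        df.index.name, df.columns.name = "datetime", "variables"
        data = xr.DataArray(df).expand_dims(Stations=self.station)  # dims: Stations, datetime, variables
        meta = pd.DataFrame({self.station[0]: {"station_type": "unknown"}})
        return data, meta

    def check_station_meta(self):
        pass  # no meta data constraints for plain csv files


gen = DataGenerator("data/", ["DEBW107"], ["o3", "temp"], "datetime", "variables", "o3",
                    statistics_per_var={"o3": "dma8eu", "temp": "maximum"},
                    data_preparation=CsvDataPrep)

Everything downstream (interpolation, windowing, transformation) stays in the shared base class, which is the point of this refactoring.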
@@ -288,8 +288,8 @@ class DataGenerator(keras.utils.Sequence): data = self._load_pickle_data(station, self.variables) except FileNotFoundError: logging.debug(f"could not load pickle data for {station}") - data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type, - **self.kwargs) + data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type, + **self.kwargs) if self.transformation is not None: data.transform("datetime", **helpers.remove_items(self.transformation, "scope")) data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill) diff --git a/mlair/data_handling/data_preparation.py b/mlair/data_handling/data_preparation.py index f500adec5c7a2d7fac67f3e6e9ba2fc61079c115..1dce5c87c2b076621ee08ae0f18906fd47d95e95 100644 --- a/mlair/data_handling/data_preparation.py +++ b/mlair/data_handling/data_preparation.py @@ -1,7 +1,7 @@ """Data Preparation class to handle data processing for machine learning.""" -__author__ = 'Felix Kleinert, Lukas Leufen' -__date__ = '2019-10-16' +__author__ = 'Lukas Leufen' +__date__ = '2020-06-29' import datetime as dt import logging @@ -25,7 +25,7 @@ num_or_list = Union[number, List[number]] data_or_none = Union[xr.DataArray, None] -class DataPrep(object): +class AbstractDataPrep(object): """ This class prepares data to be used in neural networks. @@ -55,14 +55,11 @@ class DataPrep(object): """ - def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str], - station_type: str = None, **kwargs): + def __init__(self, path: str, station: Union[str, List[str]], variables: List[str], **kwargs): """Construct instance.""" self.path = os.path.abspath(path) - self.network = network self.station = helpers.to_list(station) self.variables = variables - self.station_type = station_type self.mean: data_or_none = None self.std: data_or_none = None self.history: data_or_none = None @@ -81,92 +78,60 @@ class DataPrep(object): else: raise NotImplementedError("Either select hourly data or provide statistics_per_var.") - def load_data(self): + def load_data(self, source_name=""): """ - Load data and meta data either from local disk (preferred) or download new data from TOAR database. + Load data and meta data either from local disk (preferred) or download new data using a custom download method. Data is downloaded if no local data is available or if the parameter overwrite_local_data is true. In both cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not set, it is assumed that data should be saved locally.
""" + source_name = source_name if len(source_name) == 0 else f" from {source_name}" check_path_and_create(self.path) file_name = self._set_file_name() meta_file = self._set_meta_file_name() if self.kwargs.get('overwrite_local_data', False): - logging.debug(f"overwrite_local_data is true, therefore reload {file_name} from JOIN") + logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}") if os.path.exists(file_name): os.remove(file_name) if os.path.exists(meta_file): os.remove(meta_file) - self.download_data(file_name, meta_file) - logging.debug("loaded new data from JOIN") + data, self.meta = self.download_data(file_name, meta_file) + logging.debug(f"loaded new data{source_name}") else: try: logging.debug(f"try to load local data from: {file_name}") - data = self._slice_prep(xr.open_dataarray(file_name)) - self.data = self.check_for_negative_concentrations(data) + data = xr.open_dataarray(file_name) self.meta = pd.read_csv(meta_file, index_col=0) - if self.station_type is not None: - self.check_station_meta() + self.check_station_meta() logging.debug("loading finished") except FileNotFoundError as e: logging.debug(e) - self.download_data(file_name, meta_file) - logging.debug("loaded new data from JOIN") + logging.debug(f"load new data{source_name}") + data, self.meta = self.download_data(file_name, meta_file) + logging.debug("loading finished") + # create slices and check for negative concentration. + data = self._slice_prep(data) + self.data = self.check_for_negative_concentrations(data) - def download_data(self, file_name, meta_file): + def download_data(self, file_name, meta_file) -> [xr.DataArray, pd.DataFrame]: """ - Download data from join, create slices and check for negative concentration. - - Handle sequence of required operation on new data downloads. First, download data using class method - download_data_from_join. Second, slice data using _slice_prep and lastly check for negative concentrations in - data with check_for_negative_concentrations. Finally, data is stored in instance attribute data. + Download data and meta. :param file_name: name of file to save data to (containing full path) :param meta_file: name of the meta data file (also containing full path) """ - data, self.meta = self.download_data_from_join(file_name, meta_file) - data = self._slice_prep(data) - self.data = self.check_for_negative_concentrations(data) + raise NotImplementedError def check_station_meta(self): """ - Search for the entries in meta data and compare the value with the requested values. + Placeholder function to implement some additional station meta data check if desired. - Will raise a FileNotFoundError if the values mismatch. + Ideally, this method should raise a FileNotFoundError if a value mismatch to load fresh data from a source. If + this method is not required for your application just inherit and add the `pass` command inside the method. The + NotImplementedError is more a reminder that you could use it. """ - check_dict = {"station_type": self.station_type, "network_name": self.network} - for (k, v) in check_dict.items(): - if self.meta.at[k, self.station[0]] != v: - logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " - f"{self.meta.at[k, self.station[0]]} (local). 
Raise FileNotFoundError to trigger new " - f"grapping from web.") - raise FileNotFoundError - - def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]: - """ - Download data from TOAR database using the JOIN interface. - - Data is transformed to a xarray dataset. If class attribute store_data_locally is true, data is additionally - stored locally using given names for file and meta file. - - :param file_name: name of file to save data to (containing full path) - :param meta_file: name of the meta data file (also containing full path) - - :return: downloaded data and its meta data - """ - df_all = {} - df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var, - station_type=self.station_type, network_name=self.network, sampling=self.sampling) - df_all[self.station[0]] = df - # convert df_all to xarray - xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()} - xarr = xr.Dataset(xarr).to_array(dim='Stations') - if self.kwargs.get('store_data_locally', True): - # save locally as nc/csv file - xarr.to_netcdf(path=file_name) - meta.to_csv(meta_file) - return xarr, meta + raise NotImplementedError def _set_file_name(self): all_vars = sorted(self.statistics_per_var.keys()) @@ -178,8 +143,8 @@ class DataPrep(object): def __repr__(self): """Represent class attributes.""" - return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \ - f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})" + return f"AbstractDataPrep(path='{self.path}', station={self.station}, variables={self.variables}, " \ + f"**{self.kwargs})" def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, **kwargs): @@ -589,5 +554,5 @@ class DataPrep(object): if __name__ == "__main__": - dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + dp = AbstractDataPrep('data/', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) print(dp) diff --git a/mlair/data_handling/data_preparation_join.py b/mlair/data_handling/data_preparation_join.py new file mode 100644 index 0000000000000000000000000000000000000000..86c7dee055c8258069307567b28ffcd113e13477 --- /dev/null +++ b/mlair/data_handling/data_preparation_join.py @@ -0,0 +1,124 @@ +"""Data Preparation class to handle data processing for machine learning.""" + +__author__ = 'Felix Kleinert, Lukas Leufen' +__date__ = '2019-10-16' + +import datetime as dt +import inspect +import logging +from typing import Union, List + +import pandas as pd +import xarray as xr + +from mlair import helpers +from mlair.helpers import join +from mlair.data_handling.data_preparation import AbstractDataPrep + +# define a more general date type for type hinting +date = Union[dt.date, dt.datetime] +str_or_list = Union[str, List[str]] +number = Union[float, int] +num_or_list = Union[number, List[number]] +data_or_none = Union[xr.DataArray, None] + + +class DataPrepJoin(AbstractDataPrep): + """ + This class prepares data to be used in neural networks. + + The instance searches for locally stored data that meet the given demands. If no local data is found, the + DataPrepJoin instance will load data from the TOAR database and store this data locally to use the next time. For + the moment, there is only support for daily aggregated time series. The aggregation can be set manually and differ + for each variable.
+ + After data loading, different data pre-processing steps can be executed to prepare the data for further + applications. In particular, the following methods can be used for the pre-processing step: + + - interpolate: interpolate between data points by using xarray's interpolation method + - standardise: standardise data to mean=0 and std=1, centralise to mean=0; additional methods like normalise on \ + interval [0, 1] are not implemented yet. + - make window history: represent the history (time steps before) for training/testing; X + - make labels: create target vector with given leading time steps for training/testing; y + - remove NaNs jointly from desired input and output, only keeps time steps where no NaNs are present in X AND y. \ + Use this method after the creation of the window history and labels to clean up the data cube. + + To create a DataPrepJoin instance, you need to specify the station by id (e.g. "DEBW107"), its network (e.g. UBA, + "Umweltbundesamt") and the variables to use. Further options can be set in the instance. + + * `statistics_per_var`: define a specific statistic to extract from the TOAR database for each variable. + * `start`: define a start date for the data cube creation. Default: Use the first entry in time series + * `end`: set the end date for the data cube. Default: Use last date in time series. + * `store_data_locally`: store recently downloaded data on local disk. Default: True + * set further parameters for xarray's interpolation methods to modify the interpolation scheme + + """ + + def __init__(self, path: str, station: Union[str, List[str]], variables: List[str], network: str = None, + station_type: str = None, **kwargs): + self.network = network + self.station_type = station_type + params = helpers.remove_items(inspect.getfullargspec(AbstractDataPrep.__init__).args, "self") + kwargs = {**{k: v for k, v in locals().items() if k in params and v is not None}, **kwargs} + super().__init__(**kwargs) + + def download_data(self, file_name, meta_file): + """ + Download data and meta data from JOIN. + + :param file_name: name of file to save data to (containing full path) + :param meta_file: name of the meta data file (also containing full path) + """ + data, meta = self.download_data_from_join(file_name, meta_file) + return data, meta + + def check_station_meta(self): + """ + Search for the entries in meta data and compare the value with the requested values. + + Will raise a FileNotFoundError if the values mismatch. + """ + if self.station_type is not None: + check_dict = {"station_type": self.station_type, "network_name": self.network} + for (k, v) in check_dict.items(): + if v is None: + continue + if self.meta.at[k, self.station[0]] != v: + logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " + f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new " + f"grabbing from the web.") + raise FileNotFoundError + + def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]: + """ + Download data from the TOAR database using the JOIN interface. + + Data is transformed to an xarray dataset. If class attribute store_data_locally is true, data is additionally + stored locally using given names for file and meta file.
+ + :param file_name: name of file to save data to (containing full path) + :param meta_file: name of the meta data file (also containing full path) + + :return: downloaded data and its meta data + """ + df_all = {} + df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var, + station_type=self.station_type, network_name=self.network, sampling=self.sampling) + df_all[self.station[0]] = df + # convert df_all to xarray + xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()} + xarr = xr.Dataset(xarr).to_array(dim='Stations') + if self.kwargs.get('store_data_locally', True): + # save locally as nc/csv file + xarr.to_netcdf(path=file_name) + meta.to_csv(meta_file) + return xarr, meta + + def __repr__(self): + """Represent class attributes.""" + return f"DataPrepJoin(path='{self.path}', network='{self.network}', station={self.station}, " \ + f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})" + + +if __name__ == "__main__": + dp = DataPrepJoin('data/', 'DEBW107', ['o3', 'temp'], network='dummy', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + print(dp)
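The kwargs line in DataPrepJoin.__init__ above is subtle: the outermost iterable of the comprehension, locals().items(), is evaluated in the enclosing function scope, so it picks up exactly those constructor arguments that match the parent signature. A self-contained sketch of the same forwarding pattern (Parent and Child are placeholder names):

import inspect


class Parent:
    def __init__(self, path, station, variables, **kwargs):
        self.path, self.station, self.variables, self.kwargs = path, station, variables, kwargs


class Child(Parent):
    def __init__(self, path, station, variables, network=None, **kwargs):
        self.network = network  # consumed here, not known to Parent
        # locals() sees path, station and variables at this point; everything
        # matching the parent signature is forwarded, the rest stays in kwargs
        params = [p for p in inspect.getfullargspec(Parent.__init__).args if p != "self"]
        kwargs = {**{k: v for k, v in locals().items() if k in params and v is not None}, **kwargs}
        super().__init__(**kwargs)


child = Child("data/", "DEBW107", ["o3", "temp"], network="AIRBASE", start=2010)
assert child.kwargs == {"start": 2010} and child.network == "AIRBASE"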
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index c40847eddc054e3fc32923d5fee87d4458f76754..54a51d4270047c8cfb4a3092ac4fe959b9be9e3a 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -18,6 +18,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST +from mlair.data_handling import DataPrepJoin from mlair.run_modules.run_environment import RunEnvironment from mlair.model_modules.model_class import MyLittleModel as VanillaModel @@ -228,7 +229,7 @@ class ExperimentSetup(RunEnvironment): train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None, create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None, - batch_size=None, epochs=None): + batch_size=None, epochs=None, data_preparation=None): # create run framework super().__init__() @@ -296,6 +297,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("sampling", sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") + self._set_param("data_preparation", data_preparation, default=DataPrepJoin) # target self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index f096d3b58290d8f6816962d302b5e7a10223c864..d390ecf05b2e3144b15edba0e30da7eb2b7e430c 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -13,7 +13,7 @@ import numpy as np import pandas as pd import xarray as xr -from mlair.data_handling import BootStraps, Distributor, DataGenerator, DataPrep +from mlair.data_handling import BootStraps, Distributor, DataGenerator, DataPrepJoin from mlair.helpers.datastore import NameNotFoundInDataStore from mlair.helpers import TimeTracking, statistics from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel @@ -358,7 +358,7 @@ class PostProcessing(RunEnvironment): return getter.get(self._sampling, None) @staticmethod - def _create_observation(data: DataPrep, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str, + def _create_observation(data: DataPrepJoin, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray: """ Create observation as ground truth from given data. @@ -402,7 +402,7 @@ class PostProcessing(RunEnvironment): ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols return ols_prediction - def _create_persistence_forecast(self, data: DataPrep, persistence_prediction: xr.DataArray, mean: xr.DataArray, + def _create_persistence_forecast(self, data: DataPrepJoin, persistence_prediction: xr.DataArray, mean: xr.DataArray, std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray: """ Create persistence forecast with given data. diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index c5955fc001dd27000e6e50d146eef129ad52f54a..c0d53aedabfb1c3aa6ab69219ec67f9e78c8b173 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -16,10 +16,10 @@ from mlair.configuration import path_config from mlair.helpers.join import EmptyQueryResult from mlair.run_modules.run_environment import RunEnvironment -DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] +DEFAULT_ARGS_LIST = ["data_path", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length", "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation", - "extreme_values", "extremes_on_right_tail_only"] + "extreme_values", "extremes_on_right_tail_only", "network", "data_preparation"] class PreProcessing(RunEnvironment): diff --git a/mlair/run_script.py b/mlair/run_script.py index 2d1f6aeee89c32da56b088a22b06a4b03a58674b..55e20e1e6914de27fc9d13893edacc504ab554f7 100644 --- a/mlair/run_script.py +++ b/mlair/run_script.py @@ -28,7 +28,8 @@ def run(stations=None, plot_list=None, model=None, batch_size=None, - epochs=None): + epochs=None, + data_preparation=None): params = inspect.getfullargspec(DefaultWorkflow).args kwargs = {k: v for k, v in locals().items() if k in params and v is not None} diff --git a/mlair/workflows/default_workflow.py b/mlair/workflows/default_workflow.py index c1a9c749865898c3eaa5493ee289602963025464..f42c0389d81f655fb0c8582a15e42acc853f757d 100644 --- a/mlair/workflows/default_workflow.py +++ b/mlair/workflows/default_workflow.py @@ -36,7 +36,8 @@ class DefaultWorkflow(Workflow): plot_list=None, model=None, batch_size=None, - epochs=None): + epochs=None, + data_preparation=None): super().__init__() # extract all given kwargs arguments @@ -80,7 +81,8 @@ class DefaultWorkflowHPC(Workflow): plot_list=None, model=None, batch_size=None, - epochs=None): + epochs=None, + data_preparation=None): super().__init__() # extract all given kwargs arguments diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py index c47ec59af4280cfa72e27f1bce06348e5aa2876f..0d5f3a69b08fa646b66691e1265b9bfe05f114a5 100644 --- a/test/test_data_handling/test_bootstraps.py +++ b/test/test_data_handling/test_bootstraps.py @@ -9,13 +9,14 @@ import xarray as xr from
mlair.data_handling.bootstraps import BootStraps, CreateShuffledData, BootStrapGenerator from mlair.data_handling.data_generator import DataGenerator +from mlair.data_handling import DataPrepJoin @pytest.fixture def orig_generator(data_path): - return DataGenerator(data_path, 'AIRBASE', ['DEBW107', 'DEBW013'], - ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014, - statistics_per_var={"o3": "dma8eu", "temp": "maximum"}) + return DataGenerator(data_path, ['DEBW107', 'DEBW013'], ['o3', 'temp'], 'datetime', 'variables', 'o3', + start=2010, end=2014, statistics_per_var={"o3": "dma8eu", "temp": "maximum"}, + data_preparation=DataPrepJoin) @pytest.fixture diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index a1fe2f667f33896d2bd2f4c4ad69713020dc7caf..d01133b58c37567f557543e7a4663717d15d71c7 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -7,6 +7,7 @@ import pytest from mlair.data_handling.data_distributor import Distributor from mlair.data_handling.data_generator import DataGenerator +from mlair.data_handling import DataPrepJoin from test.test_modules.test_training import my_test_model @@ -14,14 +15,16 @@ class TestDistributor: @pytest.fixture def generator(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], - 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], + 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, + data_preparation=DataPrepJoin) @pytest.fixture def generator_two_stations(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'], + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), ['DEBW107', 'DEBW013'], ['o3', 'temp'], 'datetime', 'variables', 'o3', - statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, + data_preparation=DataPrepJoin) @pytest.fixture def model(self): diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 579b34f6b24d126bc10d28a255ed2f6f63662bdd..cb86d174e598a0a0a839bac9a3e43b8ade539ee1 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -7,29 +7,24 @@ import pytest import xarray as xr from mlair.data_handling.data_generator import DataGenerator -from mlair.data_handling.data_preparation import DataPrep +from mlair.data_handling import DataPrepJoin from mlair.helpers.join import EmptyQueryResult class TestDataGenerator: - # @pytest.fixture(autouse=True, scope='module') - # def teardown_module(self): - # yield - # if "data" in os.listdir(os.path.dirname(__file__)): - # shutil.rmtree(os.path.join(os.path.dirname(__file__), "data"), ignore_errors=True) - @pytest.fixture def gen(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], - 'datetime', 'variables', 'o3', start=2010, end=2014) + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], + 'datetime', 'variables', 'o3', start=2010, end=2014, data_preparation=DataPrepJoin) @pytest.fixture def gen_with_transformation(self): - return 
DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014, transformation={"scope": "data", "mean": "estimate"}, - statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, + data_preparation=DataPrepJoin) @pytest.fixture def gen_no_init(self): @@ -39,9 +34,9 @@ class TestDataGenerator: if not os.path.exists(path): os.makedirs(path) generator.stations = ["DEBW107", "DEBW013", "DEBW001"] - generator.network = "AIRBASE" generator.variables = ["temp", "o3"] generator.station_type = "background" + generator.DataPrep = DataPrepJoin generator.kwargs = {"start": 2010, "end": 2014, "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}} return generator @@ -50,8 +45,8 @@ class TestDataGenerator: tmp = np.nan for station in gen_no_init.stations: try: - data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, - station_type=gen_no_init.station_type, **gen_no_init.kwargs) + data_prep = DataPrepJoin(gen_no_init.data_path, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) tmp = data_prep.data.combine_first(tmp) except EmptyQueryResult: continue @@ -64,8 +59,8 @@ class TestDataGenerator: mean, std = None, None for station in gen_no_init.stations: try: - data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, - station_type=gen_no_init.station_type, **gen_no_init.kwargs) + data_prep = DataPrepJoin(gen_no_init.data_path, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) mean = data_prep.data.mean(axis=1).combine_first(mean) std = data_prep.data.std(axis=1).combine_first(std) except EmptyQueryResult: @@ -82,7 +77,6 @@ class TestDataGenerator: def test_init(self, gen): assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data') - assert gen.network == 'AIRBASE' assert gen.stations == ['DEBW107'] assert gen.variables == ['o3', 'temp'] assert gen.station_type is None @@ -98,7 +92,7 @@ class TestDataGenerator: def test_repr(self, gen): path = os.path.join(os.path.dirname(__file__), 'data') - assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='AIRBASE', stations=['DEBW107'], " \ + assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', stations=['DEBW107'], " \ f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \ f"target_dim='variables', target_var='o3', **{{'start': 2010, 'end': 2014}})" \ .rstrip() @@ -222,13 +216,13 @@ class TestDataGenerator: if os.path.exists(file): os.remove(file) assert not os.path.exists(file) - assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrep) + assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrepJoin) t = os.stat(file).st_ctime assert os.path.exists(file) - assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) + assert isinstance(gen.get_data_generator("DEBW107"), DataPrepJoin) assert os.stat(file).st_mtime == t os.remove(file) - assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) + assert isinstance(gen.get_data_generator("DEBW107"), DataPrepJoin) assert os.stat(file).st_ctime > t def test_get_data_generator_transform(self, gen_with_transformation): diff --git 
a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 4106e4a75a0ecd295a5ec4d2ffcab6c98b7a3b04..ebd351b020ce8a5902cbe7ed201876ce610b8f6a 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -8,121 +8,52 @@ import pandas as pd import pytest import xarray as xr -from mlair.data_handling.data_preparation import DataPrep +from mlair.data_handling.data_preparation import AbstractDataPrep +from mlair.data_handling import DataPrepJoin as DataPrep from mlair.helpers.join import EmptyQueryResult -class TestDataPrep: - - @pytest.fixture - def data(self): - return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], - station_type='background', test='testKWARGS', - statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) +class TestAbstractDataPrep: @pytest.fixture def data_prep_no_init(self): - d = object.__new__(DataPrep) + d = object.__new__(AbstractDataPrep) d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data') - d.network = 'UBA' d.station = ['DEBW107'] d.variables = ['o3', 'temp'] d.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'} - d.station_type = "background" d.sampling = "daily" - d.kwargs = None + d.kwargs = {} return d - def test_init(self, data): - assert data.path == os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data') - assert data.network == 'AIRBASE' - assert data.station == ['DEBW107'] - assert data.variables == ['o3', 'temp'] - assert data.station_type == "background" - assert data.statistics_per_var == {'o3': 'dma8eu', 'temp': 'maximum'} - assert not any([data.mean, data.std, data.history, data.label, data.observation]) - assert {'test': 'testKWARGS'}.items() <= data.kwargs.items() + @pytest.fixture + def data(self): + return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, network="AIRBASE").data - def test_init_no_stats(self): + @pytest.fixture + def data_prep(self, data_prep_no_init, data): + data_prep_no_init.mean = None + data_prep_no_init.std = None + data_prep_no_init.history = None + data_prep_no_init.label = None + data_prep_no_init.observation = None + data_prep_no_init.extremes_history = None + data_prep_no_init.extremes_label = None + data_prep_no_init.data = None + data_prep_no_init.meta = None + data_prep_no_init._transform_method = None + data_prep_no_init.data = data + return data_prep_no_init + + def test_all_placeholders(self, data_prep_no_init): + with pytest.raises(NotImplementedError): + data_prep_no_init.download_data("a", "b") with pytest.raises(NotImplementedError): - DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp']) - - def test_download_data(self, data_prep_no_init): - file_name = data_prep_no_init._set_file_name() - meta_file = data_prep_no_init._set_meta_file_name() - data_prep_no_init.kwargs = {"store_data_locally": False} - data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'} - data_prep_no_init.download_data(file_name, meta_file) - assert isinstance(data_prep_no_init.data, xr.DataArray) - - def test_download_data_from_join(self, data_prep_no_init): - file_name = data_prep_no_init._set_file_name() - meta_file = data_prep_no_init._set_meta_file_name() - data_prep_no_init.kwargs = {"store_data_locally": False} - data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'} - xarr, meta = 
data_prep_no_init.download_data_from_join(file_name, meta_file) - assert isinstance(xarr, xr.DataArray) - assert isinstance(meta, pd.DataFrame) - - def test_check_station_meta(self, caplog, data_prep_no_init): - caplog.set_level(logging.DEBUG) - file_name = data_prep_no_init._set_file_name() - meta_file = data_prep_no_init._set_meta_file_name() - data_prep_no_init.kwargs = {"store_data_locally": False} - data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'} - data_prep_no_init.download_data(file_name, meta_file) - assert data_prep_no_init.check_station_meta() is None - data_prep_no_init.station_type = "traffic" - with pytest.raises(FileNotFoundError) as e: data_prep_no_init.check_station_meta() - msg = "meta data does not agree with given request for station_type: traffic (requested) != background (local)" - assert caplog.record_tuples[-1][:-1] == ('root', 10) - assert msg in caplog.record_tuples[-1][-1] - - def test_load_data_overwrite_local_data(self, data_prep_no_init): - data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'} - file_path = data_prep_no_init._set_file_name() - meta_file_path = data_prep_no_init._set_meta_file_name() - os.remove(file_path) - os.remove(meta_file_path) - assert not os.path.exists(file_path) - assert not os.path.exists(meta_file_path) - data_prep_no_init.kwargs = {"overwrite_local_data": True} - data_prep_no_init.load_data() - assert os.path.exists(file_path) - assert os.path.exists(meta_file_path) - t = os.stat(file_path).st_ctime - tm = os.stat(meta_file_path).st_ctime - data_prep_no_init.load_data() - assert os.path.exists(file_path) - assert os.path.exists(meta_file_path) - assert os.stat(file_path).st_ctime > t - assert os.stat(meta_file_path).st_ctime > tm - assert isinstance(data_prep_no_init.data, xr.DataArray) - assert isinstance(data_prep_no_init.meta, pd.DataFrame) - - def test_load_data_keep_local_data(self, data_prep_no_init): - data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'} - data_prep_no_init.station_type = None - data_prep_no_init.kwargs = {} - file_path = data_prep_no_init._set_file_name() - data_prep_no_init.load_data() - assert os.path.exists(file_path) - t = os.stat(file_path).st_ctime - data_prep_no_init.load_data() - assert os.path.exists(data_prep_no_init._set_file_name()) - assert os.stat(file_path).st_ctime == t - assert isinstance(data_prep_no_init.data, xr.DataArray) - assert isinstance(data_prep_no_init.meta, pd.DataFrame) - - def test_repr(self, data_prep_no_init): - path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data') - assert data_prep_no_init.__repr__().rstrip() == f"Dataprep(path='{path}', network='UBA', " \ - f"station=['DEBW107'], variables=['o3', 'temp'], " \ - f"station_type=background, **None)".rstrip() def test_set_file_name_and_meta(self): - d = object.__new__(DataPrep) + d = object.__new__(AbstractDataPrep) d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "data") d.station = 'TESTSTATION' d.variables = ['a', 'bc'] @@ -134,41 +65,41 @@ class TestDataPrep: @pytest.mark.parametrize('opts', [{'dim': 'datetime', 'method': 'nearest', 'limit': 10, 'use_coordinate': True}, {'dim': 'datetime', 'limit': 5}, {'dim': 'datetime'}]) - def test_interpolate(self, data, opts): - data_org = data.data - data.interpolate(**opts) + def test_interpolate(self, data_prep, opts): + data_org = data_prep.data + data_prep.interpolate(**opts) # set default params if empty opts["method"] = opts.get("method", 'linear') opts["limit"] = 
opts.get("limit", None) opts["use_coordinate"] = opts.get("use_coordinate", True) - assert xr.testing.assert_equal(data_org.interpolate_na(**opts), data.data) is None - - def test_transform_standardise(self, data): - assert data._transform_method is None - assert data.mean is None - assert data.std is None - data.transform('datetime') - assert data._transform_method == 'standardise' - assert np.testing.assert_almost_equal(data.data.mean('datetime').variable.values, np.array([[0, 0]])) is None - assert np.testing.assert_almost_equal(data.data.std('datetime').variable.values, np.array([[1, 1]])) is None - assert isinstance(data.mean, xr.DataArray) - assert isinstance(data.std, xr.DataArray) - - def test_transform_standardise_apply(self, data): - assert data._transform_method is None - assert data.mean is None - assert data.std is None - data_mean_orig = data.data.mean('datetime').variable.values - data_std_orig = data.data.std('datetime').variable.values + assert xr.testing.assert_equal(data_org.interpolate_na(**opts), data_prep.data) is None + + def test_transform_standardise(self, data_prep): + assert data_prep._transform_method is None + assert data_prep.mean is None + assert data_prep.std is None + data_prep.transform('datetime') + assert data_prep._transform_method == 'standardise' + assert np.testing.assert_almost_equal(data_prep.data.mean('datetime').variable.values, np.array([[0, 0]])) is None + assert np.testing.assert_almost_equal(data_prep.data.std('datetime').variable.values, np.array([[1, 1]])) is None + assert isinstance(data_prep.mean, xr.DataArray) + assert isinstance(data_prep.std, xr.DataArray) + + def test_transform_standardise_apply(self, data_prep): + assert data_prep._transform_method is None + assert data_prep.mean is None + assert data_prep.std is None + data_mean_orig = data_prep.data.mean('datetime').variable.values + data_std_orig = data_prep.data.std('datetime').variable.values mean_external = np.array([20, 12]) std_external = np.array([15, 5]) mean = xr.DataArray(mean_external, coords={"variables": ['o3', 'temp']}, dims=["variables"]) std = xr.DataArray(std_external, coords={"variables": ['o3', 'temp']}, dims=["variables"]) - data.transform('datetime', mean=mean, std=std) - assert all(data.mean.values == mean_external) - assert all(data.std.values == std_external) - data_mean_transformed = data.data.mean('datetime').variable.values - data_std_transformed = data.data.std('datetime').variable.values + data_prep.transform('datetime', mean=mean, std=std) + assert all(data_prep.mean.values == mean_external) + assert all(data_prep.std.values == std_external) + data_mean_transformed = data_prep.data.mean('datetime').variable.values + data_std_transformed = data_prep.data.std('datetime').variable.values data_mean_expected = (data_mean_orig - mean_external) / std_external # mean scales as any other data data_std_expected = data_std_orig / std_external # std scales by given std assert np.testing.assert_almost_equal(data_mean_transformed, data_mean_expected) is None @@ -178,129 +109,129 @@ class TestDataPrep: (None, 3, 'standardise', 'mean, '), (19, None, 'centre', ''), (None, 2, 'centre', 'mean, '), (8, 2, 'centre', ''), (None, None, 'standardise', 'mean, std, ')]) - def test_check_inverse_transform_params(self, data, mean, std, method, msg): + def test_check_inverse_transform_params(self, data_prep, mean, std, method, msg): if len(msg) > 0: with pytest.raises(AttributeError) as e: - data.check_inverse_transform_params(mean, std, method) + 
data_prep.check_inverse_transform_params(mean, std, method) assert msg in e.value.args[0] else: - assert data.check_inverse_transform_params(mean, std, method) is None - - def test_transform_centre(self, data): - assert data._transform_method is None - assert data.mean is None - assert data.std is None - data_std_orig = data.data.std('datetime').variable.values - data.transform('datetime', 'centre') - assert data._transform_method == 'centre' - assert np.testing.assert_almost_equal(data.data.mean('datetime').variable.values, np.array([[0, 0]])) is None - assert np.testing.assert_almost_equal(data.data.std('datetime').variable.values, data_std_orig) is None - assert data.std is None - - def test_transform_centre_apply(self, data): - assert data._transform_method is None - assert data.mean is None - assert data.std is None - data_mean_orig = data.data.mean('datetime').variable.values - data_std_orig = data.data.std('datetime').variable.values + assert data_prep.check_inverse_transform_params(mean, std, method) is None + + def test_transform_centre(self, data_prep): + assert data_prep._transform_method is None + assert data_prep.mean is None + assert data_prep.std is None + data_std_orig = data_prep.data.std('datetime').variable.values + data_prep.transform('datetime', 'centre') + assert data_prep._transform_method == 'centre' + assert np.testing.assert_almost_equal(data_prep.data.mean('datetime').variable.values, np.array([[0, 0]])) is None + assert np.testing.assert_almost_equal(data_prep.data.std('datetime').variable.values, data_std_orig) is None + assert data_prep.std is None + + def test_transform_centre_apply(self, data_prep): + assert data_prep._transform_method is None + assert data_prep.mean is None + assert data_prep.std is None + data_mean_orig = data_prep.data.mean('datetime').variable.values + data_std_orig = data_prep.data.std('datetime').variable.values mean_external = np.array([20, 12]) mean = xr.DataArray(mean_external, coords={"variables": ['o3', 'temp']}, dims=["variables"]) - data.transform('datetime', 'centre', mean=mean) - assert all(data.mean.values == mean_external) - assert data.std is None - data_mean_transformed = data.data.mean('datetime').variable.values - data_std_transformed = data.data.std('datetime').variable.values + data_prep.transform('datetime', 'centre', mean=mean) + assert all(data_prep.mean.values == mean_external) + assert data_prep.std is None + data_mean_transformed = data_prep.data.mean('datetime').variable.values + data_std_transformed = data_prep.data.std('datetime').variable.values data_mean_expected = (data_mean_orig - mean_external) # mean scales as any other data assert np.testing.assert_almost_equal(data_mean_transformed, data_mean_expected) is None assert np.testing.assert_almost_equal(data_std_transformed, data_std_orig) is None @pytest.mark.parametrize('method', ['standardise', 'centre']) - def test_transform_inverse(self, data, method): - data_org = data.data - data.transform('datetime', method) - data.inverse_transform() - assert data._transform_method is None - assert data.mean is None - assert data.std is None - assert np.testing.assert_array_almost_equal(data_org, data.data) is None - data.transform('datetime', method) - data.transform('datetime', inverse=True) - assert data._transform_method is None - assert data.mean is None - assert data.std is None - assert np.testing.assert_array_almost_equal(data_org, data.data) is None + def test_transform_inverse(self, data_prep, method): + data_org = data_prep.data + 
data_prep.transform('datetime', method) + data_prep.inverse_transform() + assert data_prep._transform_method is None + assert data_prep.mean is None + assert data_prep.std is None + assert np.testing.assert_array_almost_equal(data_org, data_prep.data) is None + data_prep.transform('datetime', method) + data_prep.transform('datetime', inverse=True) + assert data_prep._transform_method is None + assert data_prep.mean is None + assert data_prep.std is None + assert np.testing.assert_array_almost_equal(data_org, data_prep.data) is None @pytest.mark.parametrize('method', ['normalise', 'unknownmethod']) - def test_transform_errors(self, data, method): + def test_transform_errors(self, data_prep, method): with pytest.raises(NotImplementedError): - data.transform('datetime', method) - data._transform_method = method + data_prep.transform('datetime', method) + data_prep._transform_method = method with pytest.raises(AssertionError) as e: - data.transform('datetime', method) + data_prep.transform('datetime', method) assert "Transform method is already set." in e.value.args[0] @pytest.mark.parametrize('method', ['normalise', 'unknownmethod']) - def test_transform_inverse_errors(self, data, method): + def test_transform_inverse_errors(self, data_prep, method): with pytest.raises(AssertionError) as e: - data.inverse_transform() + data_prep.inverse_transform() assert "Inverse transformation method is not set." in e.value.args[0] - data.mean = 1 - data.std = 1 - data._transform_method = method + data_prep.mean = 1 + data_prep.std = 1 + data_prep._transform_method = method with pytest.raises(NotImplementedError): - data.inverse_transform() - - def test_get_transformation_information(self, data): - assert (None, None, None) == data.get_transformation_information("o3") - mean_test = data.data.mean("datetime").sel(variables='o3').values - std_test = data.data.std("datetime").sel(variables='o3').values - data.transform('datetime') - mean, std, info = data.get_transformation_information("o3") + data_prep.inverse_transform() + + def test_get_transformation_information(self, data_prep): + assert (None, None, None) == data_prep.get_transformation_information("o3") + mean_test = data_prep.data.mean("datetime").sel(variables='o3').values + std_test = data_prep.data.std("datetime").sel(variables='o3').values + data_prep.transform('datetime') + mean, std, info = data_prep.get_transformation_information("o3") assert np.testing.assert_almost_equal(mean, mean_test) is None assert np.testing.assert_almost_equal(std, std_test) is None assert info == "standardise" - def test_remove_nan_no_hist_or_label(self, data): - assert not any([data.history, data.label, data.observation]) - data.remove_nan('datetime') - assert not any([data.history, data.label, data.observation]) - data.make_history_window('variables', 6, 'datetime') - assert data.history is not None - data.remove_nan('datetime') - assert data.history is None - data.make_labels('variables', 'o3', 'datetime', 2) - data.make_observation('variables', 'o3', 'datetime') - assert all(map(lambda x: x is not None, [data.label, data.observation])) - data.remove_nan('datetime') - assert not any([data.history, data.label, data.observation]) - - def test_remove_nan(self, data): - data.make_history_window('variables', -12, 'datetime') - data.make_labels('variables', 'o3', 'datetime', 3) - data.make_observation('variables', 'o3', 'datetime') - shape = data.history.shape - data.remove_nan('datetime') - assert data.history.isnull().sum() == 0 - assert itemgetter(0, 1, 3)(shape) == 
itemgetter(0, 1, 3)(data.history.shape) - assert shape[2] >= data.history.shape[2] - remaining_len = data.history.datetime.shape - assert remaining_len == data.label.datetime.shape - assert remaining_len == data.observation.datetime.shape - - def test_remove_nan_too_short(self, data): - data.kwargs["min_length"] = 4000 # actual length of series is 3940 - data.make_history_window('variables', -12, 'datetime') - data.make_labels('variables', 'o3', 'datetime', 3) - data.make_observation('variables', 'o3', 'datetime') - data.remove_nan('datetime') - assert not any([data.history, data.label, data.observation]) - - def test_create_index_array(self, data): - index_array = data.create_index_array('window', range(1, 4)) + def test_remove_nan_no_hist_or_label(self, data_prep): + assert not any([data_prep.history, data_prep.label, data_prep.observation]) + data_prep.remove_nan('datetime') + assert not any([data_prep.history, data_prep.label, data_prep.observation]) + data_prep.make_history_window('variables', 6, 'datetime') + assert data_prep.history is not None + data_prep.remove_nan('datetime') + assert data_prep.history is None + data_prep.make_labels('variables', 'o3', 'datetime', 2) + data_prep.make_observation('variables', 'o3', 'datetime') + assert all(map(lambda x: x is not None, [data_prep.label, data_prep.observation])) + data_prep.remove_nan('datetime') + assert not any([data_prep.history, data_prep.label, data_prep.observation]) + + def test_remove_nan(self, data_prep): + data_prep.make_history_window('variables', -12, 'datetime') + data_prep.make_labels('variables', 'o3', 'datetime', 3) + data_prep.make_observation('variables', 'o3', 'datetime') + shape = data_prep.history.shape + data_prep.remove_nan('datetime') + assert data_prep.history.isnull().sum() == 0 + assert itemgetter(0, 1, 3)(shape) == itemgetter(0, 1, 3)(data_prep.history.shape) + assert shape[2] >= data_prep.history.shape[2] + remaining_len = data_prep.history.datetime.shape + assert remaining_len == data_prep.label.datetime.shape + assert remaining_len == data_prep.observation.datetime.shape + + def test_remove_nan_too_short(self, data_prep): + data_prep.kwargs["min_length"] = 4000 # actual length of series is 3940 + data_prep.make_history_window('variables', -12, 'datetime') + data_prep.make_labels('variables', 'o3', 'datetime', 3) + data_prep.make_observation('variables', 'o3', 'datetime') + data_prep.remove_nan('datetime') + assert not any([data_prep.history, data_prep.label, data_prep.observation]) + + def test_create_index_array(self, data_prep): + index_array = data_prep.create_index_array('window', range(1, 4)) assert np.testing.assert_array_equal(index_array.data, [1, 2, 3]) is None assert index_array.name == 'window' assert index_array.coords.dims == ('window',) - index_array = data.create_index_array('window', range(0, 1)) + index_array = data_prep.create_index_array('window', range(0, 1)) assert np.testing.assert_array_equal(index_array.data, [0]) is None assert index_array.name == 'window' assert index_array.coords.dims == ('window',) @@ -319,108 +250,103 @@ class TestDataPrep: orig_slice = orig.sel(slice).data.flatten() return window, orig_slice - def test_shift(self, data): - res = data.shift('datetime', 4) - window, orig = self.extract_window_data(res, data.data, 4) + def test_shift(self, data_prep): + res = data_prep.shift('datetime', 4) + window, orig = self.extract_window_data(res, data_prep.data, 4) assert res.coords.dims == ('window', 'Stations', 'datetime', 'variables') - assert list(res.data.shape) == 
[4, *data.data.shape]
+        assert list(res.data.shape) == [4, *data_prep.data.shape]
         assert np.testing.assert_array_equal(orig, window) is None
-        res = data.shift('datetime', -3)
-        window, orig = self.extract_window_data(res, data.data, -3)
-        assert list(res.data.shape) == [4, *data.data.shape]
+        res = data_prep.shift('datetime', -3)
+        window, orig = self.extract_window_data(res, data_prep.data, -3)
+        assert list(res.data.shape) == [4, *data_prep.data.shape]
         assert np.testing.assert_array_equal(orig, window) is None
-        res = data.shift('datetime', 0)
-        window, orig = self.extract_window_data(res, data.data, 0)
-        assert list(res.data.shape) == [1, *data.data.shape]
+        res = data_prep.shift('datetime', 0)
+        window, orig = self.extract_window_data(res, data_prep.data, 0)
+        assert list(res.data.shape) == [1, *data_prep.data.shape]
         assert np.testing.assert_array_equal(orig, window) is None
 
-    def test_make_history_window(self, data):
-        assert data.history is None
-        data.make_history_window("variables", 5, "datetime")
-        assert data.history is not None
-        save_history = data.history
-        data.make_history_window("variables", -5, "datetime")
-        assert np.testing.assert_array_equal(data.history, save_history) is None
-
-    def test_make_labels(self, data):
-        assert data.label is None
-        data.make_labels('variables', 'o3', 'datetime', 3)
-        assert data.label.variables.data == 'o3'
-        assert list(data.label.shape) == [3, *data.data.shape[:2]]
-        save_label = data.label.copy()
-        data.make_labels('variables', 'o3', 'datetime', -3)
-        assert np.testing.assert_array_equal(data.label, save_label) is None
-
-    def test_make_labels_multiple(self, data):
-        assert data.label is None
-        data.make_labels("variables", ["o3", "temp"], "datetime", 4)
-        assert all(data.label.variables.data == ["o3", "temp"])
-        assert list(data.label.shape) == [4, *data.data.shape[:2], 2]
-
-    def test_make_observation(self, data):
-        assert data.observation is None
-        data.make_observation("variables", "o3", "datetime")
-        assert data.observation.variables.data == "o3"
-        assert list(data.observation.shape) == [1, 1, data.data.datetime.shape[0]]
-
-    def test_make_observation_multiple(self, data):
-        assert data.observation is None
-        data.make_observation("variables", ["o3", "temp"], "datetime")
-        assert all(data.observation.variables.data == ["o3", "temp"])
-        assert list(data.observation.shape) == [1, 1, data.data.datetime.shape[0], 2]
-
-    def test_slice(self, data):
-        res = data._slice(data.data, dt.date(1997, 1, 1), dt.date(1997, 1, 10), 'datetime')
-        assert itemgetter(0, 2)(res.shape) == itemgetter(0, 2)(data.data.shape)
+    def test_make_history_window(self, data_prep):
+        assert data_prep.history is None
+        data_prep.make_history_window("variables", 5, "datetime")
+        assert data_prep.history is not None
+        save_history = data_prep.history
+        data_prep.make_history_window("variables", -5, "datetime")
+        assert np.testing.assert_array_equal(data_prep.history, save_history) is None
+
+    def test_make_labels(self, data_prep):
+        assert data_prep.label is None
+        data_prep.make_labels('variables', 'o3', 'datetime', 3)
+        assert data_prep.label.variables.data == 'o3'
+        assert list(data_prep.label.shape) == [3, *data_prep.data.shape[:2]]
+        save_label = data_prep.label.copy()
+        data_prep.make_labels('variables', 'o3', 'datetime', -3)
+        assert np.testing.assert_array_equal(data_prep.label, save_label) is None
+
+    def test_make_labels_multiple(self, data_prep):
+        assert data_prep.label is None
+        data_prep.make_labels("variables", ["o3", "temp"], "datetime", 4)
+        assert all(data_prep.label.variables.data == ["o3", "temp"])
+        assert list(data_prep.label.shape) == [4, *data_prep.data.shape[:2], 2]
+
+    def test_make_observation(self, data_prep):
+        assert data_prep.observation is None
+        data_prep.make_observation("variables", "o3", "datetime")
+        assert data_prep.observation.variables.data == "o3"
+        assert list(data_prep.observation.shape) == [1, 1, data_prep.data.datetime.shape[0]]
+
+    def test_make_observation_multiple(self, data_prep):
+        assert data_prep.observation is None
+        data_prep.make_observation("variables", ["o3", "temp"], "datetime")
+        assert all(data_prep.observation.variables.data == ["o3", "temp"])
+        assert list(data_prep.observation.shape) == [1, 1, data_prep.data.datetime.shape[0], 2]
+
+    def test_slice(self, data_prep):
+        res = data_prep._slice(data_prep.data, dt.date(1997, 1, 1), dt.date(1997, 1, 10), 'datetime')
+        assert itemgetter(0, 2)(res.shape) == itemgetter(0, 2)(data_prep.data.shape)
         assert res.shape[1] == 10
 
-    def test_slice_prep(self, data):
-        res = data._slice_prep(data.data)
-        assert res.shape == data.data.shape
-        data.kwargs['start'] = res.coords['datetime'][0].values
-        data.kwargs['end'] = res.coords['datetime'][9].values
-        res = data._slice_prep(data.data)
-        assert itemgetter(0, 2)(res.shape) == itemgetter(0, 2)(data.data.shape)
+    def test_slice_prep(self, data_prep):
+        res = data_prep._slice_prep(data_prep.data)
+        assert res.shape == data_prep.data.shape
+        data_prep.kwargs['start'] = res.coords['datetime'][0].values
+        data_prep.kwargs['end'] = res.coords['datetime'][9].values
+        res = data_prep._slice_prep(data_prep.data)
+        assert itemgetter(0, 2)(res.shape) == itemgetter(0, 2)(data_prep.data.shape)
         assert res.shape[1] == 10
 
-    def test_check_for_neg_concentrations(self, data):
-        res = data.check_for_negative_concentrations(data.data)
+    def test_check_for_neg_concentrations(self, data_prep):
+        res = data_prep.check_for_negative_concentrations(data_prep.data)
         assert res.sel({'variables': 'o3'}).min() >= 0
-        res = data.check_for_negative_concentrations(data.data, minimum=2)
+        res = data_prep.check_for_negative_concentrations(data_prep.data, minimum=2)
         assert res.sel({'variables': 'o3'}).min() >= 2
 
-    def test_check_station(self, data):
-        with pytest.raises(EmptyQueryResult):
-            data_new = DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'dummy', 'DEBW107', ['o3', 'temp'],
-                                station_type='traffic', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
-
-    def test_get_transposed_history(self, data):
-        data.make_history_window("variables", 3, "datetime")
-        transposed = data.get_transposed_history()
+    def test_get_transposed_history(self, data_prep):
+        data_prep.make_history_window("variables", 3, "datetime")
+        transposed = data_prep.get_transposed_history()
         assert transposed.coords.dims == ("datetime", "window", "Stations", "variables")
 
-    def test_get_transposed_label(self, data):
-        data.make_labels("variables", "o3", "datetime", 2)
-        transposed = data.get_transposed_label()
+    def test_get_transposed_label(self, data_prep):
+        data_prep.make_labels("variables", "o3", "datetime", 2)
+        transposed = data_prep.get_transposed_label()
         assert transposed.coords.dims == ("datetime", "window")
 
-    def test_multiply_extremes(self, data):
-        data.transform("datetime")
-        data.make_history_window("variables", 3, "datetime")
-        data.make_labels("variables", "o3", "datetime", 2)
-        orig = data.label
-        data.multiply_extremes(1)
-        upsampled = data.extremes_label
+    def test_multiply_extremes(self, data_prep):
+        data_prep.transform("datetime")
+        data_prep.make_history_window("variables", 3, "datetime")
+        data_prep.make_labels("variables", "o3", "datetime", 2)
+        orig = data_prep.label
+        data_prep.multiply_extremes(1)
+        upsampled = data_prep.extremes_label
         assert (upsampled > 1).sum() == (orig > 1).sum()
         assert (upsampled < -1).sum() == (orig < -1).sum()
 
-    def test_multiply_extremes_from_list(self, data):
-        data.transform("datetime")
-        data.make_history_window("variables", 3, "datetime")
-        data.make_labels("variables", "o3", "datetime", 2)
-        orig = data.label
-        data.multiply_extremes([1, 1.5, 2, 3])
-        upsampled = data.extremes_label
+    def test_multiply_extremes_from_list(self, data_prep):
+        data_prep.transform("datetime")
+        data_prep.make_history_window("variables", 3, "datetime")
+        data_prep.make_labels("variables", "o3", "datetime", 2)
+        orig = data_prep.label
+        data_prep.multiply_extremes([1, 1.5, 2, 3])
+        upsampled = data_prep.extremes_label
 
         def f(d, op, n):
             return op(d, n).any(dim="window").sum()
@@ -429,22 +355,22 @@ class TestDataPrep:
         assert f(upsampled, lt, -1) == sum(
             [f(orig, lt, -1), f(orig, lt, -1.5), f(orig, lt, -2) * 2, f(orig, lt, -3) * 4])
 
-    def test_multiply_extremes_wrong_extremes(self, data):
-        data.transform("datetime")
-        data.make_history_window("variables", 3, "datetime")
-        data.make_labels("variables", "o3", "datetime", 2)
+    def test_multiply_extremes_wrong_extremes(self, data_prep):
+        data_prep.transform("datetime")
+        data_prep.make_history_window("variables", 3, "datetime")
+        data_prep.make_labels("variables", "o3", "datetime", 2)
         with pytest.raises(TypeError) as e:
-            data.multiply_extremes([1, "1.5", 2])
+            data_prep.multiply_extremes([1, "1.5", 2])
         assert "Elements of list extreme_values have to be (<class 'float'>, <class 'int'>), but at least element 1.5" \
                " is type <class 'str'>" in e.value.args[0]
 
-    def test_multiply_extremes_right_tail(self, data):
-        data.transform("datetime")
-        data.make_history_window("variables", 3, "datetime")
-        data.make_labels("variables", "o3", "datetime", 2)
-        orig = data.label
-        data.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
-        upsampled = data.extremes_label
+    def test_multiply_extremes_right_tail(self, data_prep):
+        data_prep.transform("datetime")
+        data_prep.make_history_window("variables", 3, "datetime")
+        data_prep.make_labels("variables", "o3", "datetime", 2)
+        orig = data_prep.label
+        data_prep.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
+        upsampled = data_prep.extremes_label
 
         def f(d, op, n):
             return op(d, n).any(dim="window").sum()
@@ -453,39 +379,156 @@ class TestDataPrep:
         assert upsampled.shape[2] == sum([f(orig, gt, 1), f(orig, gt, 2)])
         assert f(upsampled, lt, -1) == 0
 
-    def test_multiply_extremes_none_label(self, data):
-        data.transform("datetime")
-        data.make_history_window("variables", 3, "datetime")
-        data.label = None
-        assert data.multiply_extremes([1], extremes_on_right_tail_only=False) is None
-
-    def test_multiply_extremes_none_history(self, data):
-        data.transform("datetime")
-        data.history = None
-        data.make_labels("variables", "o3", "datetime", 2)
-        assert data.multiply_extremes([1], extremes_on_right_tail_only=False) is None
-
-    def test_multiply_extremes_none_label_history(self, data):
-        data.history = None
-        data.label = None
-        assert data.multiply_extremes([1], extremes_on_right_tail_only=False) is None
-
-    def test_get_extremes_history(self, data):
-        data.transform("datetime")
-        data.make_history_window("variables", 3, "datetime")
-        data.make_labels("variables", "o3", "datetime", 2)
-        data.make_observation("variables", "o3", "datetime")
-        data.remove_nan("datetime")
-        data.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
-        assert (data.get_extremes_history() ==
-                data.extremes_history.transpose("datetime", "window", "Stations", "variables")).all()
-
-    def test_get_extremes_label(self, data):
-        data.transform("datetime")
-        data.make_history_window("variables", 3, "datetime")
-        data.make_labels("variables", "o3", "datetime", 2)
-        data.make_observation("variables", "o3", "datetime")
-        data.remove_nan("datetime")
-        data.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
-        assert (data.get_extremes_label() ==
-                data.extremes_label.squeeze("Stations").transpose("datetime", "window")).all()
+    def test_multiply_extremes_none_label(self, data_prep):
+        data_prep.transform("datetime")
+        data_prep.make_history_window("variables", 3, "datetime")
+        data_prep.label = None
+        assert data_prep.multiply_extremes([1], extremes_on_right_tail_only=False) is None
+
+    def test_multiply_extremes_none_history(self, data_prep):
+        data_prep.transform("datetime")
+        data_prep.history = None
+        data_prep.make_labels("variables", "o3", "datetime", 2)
+        assert data_prep.multiply_extremes([1], extremes_on_right_tail_only=False) is None
+
+    def test_multiply_extremes_none_label_history(self, data_prep):
+        data_prep.history = None
+        data_prep.label = None
+        assert data_prep.multiply_extremes([1], extremes_on_right_tail_only=False) is None
+
+    def test_get_extremes_history(self, data_prep):
+        data_prep.transform("datetime")
+        data_prep.make_history_window("variables", 3, "datetime")
+        data_prep.make_labels("variables", "o3", "datetime", 2)
+        data_prep.make_observation("variables", "o3", "datetime")
+        data_prep.remove_nan("datetime")
+        data_prep.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
+        assert (data_prep.get_extremes_history() ==
+                data_prep.extremes_history.transpose("datetime", "window", "Stations", "variables")).all()
+
+    def test_get_extremes_label(self, data_prep):
+        data_prep.transform("datetime")
+        data_prep.make_history_window("variables", 3, "datetime")
+        data_prep.make_labels("variables", "o3", "datetime", 2)
+        data_prep.make_observation("variables", "o3", "datetime")
+        data_prep.remove_nan("datetime")
+        data_prep.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
+        assert (data_prep.get_extremes_label() ==
+                data_prep.extremes_label.squeeze("Stations").transpose("datetime", "window")).all()
+
+
+class TestDataPrepJoin:
+
+    @pytest.fixture
+    def data(self):
+        return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
+                        station_type='background', test='testKWARGS', network="AIRBASE",
+                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+
+    @pytest.fixture
+    def data_prep_no_init(self):
+        d = object.__new__(DataPrep)
+        d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
+        d.network = 'UBA'
+        d.station = ['DEBW107']
+        d.variables = ['o3', 'temp']
+        d.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        d.station_type = "background"
+        d.sampling = "daily"
+        d.kwargs = None
+        return d
+
+    def test_init(self, data):
+        assert data.path == os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
+        assert data.network == 'AIRBASE'
+        assert data.station == ['DEBW107']
+        assert data.variables == ['o3', 'temp']
+        assert data.station_type == "background"
+        assert data.statistics_per_var == {'o3': 'dma8eu', 'temp': 'maximum'}
+        assert not any([data.mean, data.std, data.history, data.label, data.observation])
+        assert {'test': 'testKWARGS'}.items() <= data.kwargs.items()
+
+    def test_init_no_stats(self):
+        with pytest.raises(NotImplementedError):
+            DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'])
+
+    def test_download_data(self, data_prep_no_init):
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data, meta = data_prep_no_init.download_data(file_name, meta_file)
+        assert isinstance(data, xr.DataArray)
+        assert isinstance(meta, pd.DataFrame)
+
+    def test_download_data_from_join(self, data_prep_no_init):
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        xarr, meta = data_prep_no_init.download_data_from_join(file_name, meta_file)
+        assert isinstance(xarr, xr.DataArray)
+        assert isinstance(meta, pd.DataFrame)
+
+    def test_check_station_meta(self, caplog, data_prep_no_init):
+        caplog.set_level(logging.DEBUG)
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        _, meta = data_prep_no_init.download_data(file_name, meta_file)
+        data_prep_no_init.meta = meta
+        assert data_prep_no_init.check_station_meta() is None
+        data_prep_no_init.station_type = "traffic"
+        with pytest.raises(FileNotFoundError) as e:
+            data_prep_no_init.check_station_meta()
+        msg = "meta data does not agree with given request for station_type: traffic (requested) != background (local)"
+        assert caplog.record_tuples[-1][:-1] == ('root', 10)
+        assert msg in caplog.record_tuples[-1][-1]
+
+    def test_load_data_overwrite_local_data(self, data_prep_no_init):
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        file_path = data_prep_no_init._set_file_name()
+        meta_file_path = data_prep_no_init._set_meta_file_name()
+        os.remove(file_path) if os.path.exists(file_path) else None
+        os.remove(meta_file_path) if os.path.exists(meta_file_path) else None
+        assert not os.path.exists(file_path)
+        assert not os.path.exists(meta_file_path)
+        data_prep_no_init.kwargs = {"overwrite_local_data": True}
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        assert os.path.exists(meta_file_path)
+        t = os.stat(file_path).st_ctime
+        tm = os.stat(meta_file_path).st_ctime
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        assert os.path.exists(meta_file_path)
+        assert os.stat(file_path).st_ctime > t
+        assert os.stat(meta_file_path).st_ctime > tm
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+        assert isinstance(data_prep_no_init.meta, pd.DataFrame)
+
+    def test_load_data_keep_local_data(self, data_prep_no_init):
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.station_type = None
+        data_prep_no_init.kwargs = {}
+        file_path = data_prep_no_init._set_file_name()
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        t = os.stat(file_path).st_ctime
+        data_prep_no_init.load_data()
+        assert os.path.exists(data_prep_no_init._set_file_name())
+        assert os.stat(file_path).st_ctime == t
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+        assert isinstance(data_prep_no_init.meta, pd.DataFrame)
+
+    def test_repr(self, data_prep_no_init):
+        path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
+        assert data_prep_no_init.__repr__().rstrip() == f"Dataprep(path='{path}', network='UBA', " \
+                                                        f"station=['DEBW107'], variables=['o3', 'temp'], " \
+                                                        f"station_type=background, **None)".rstrip()
+
+    def test_check_station(self, data):
+        with pytest.raises(EmptyQueryResult):
+            data_new = DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'dummy', 'DEBW107', ['o3', 'temp'],
+                                station_type='traffic', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py
index b61e52d21758213c34b7830589f58e1edb53dc77..2b83d2549ea2f649091d2f16b67bf0d93789af52 100644
--- a/test/test_modules/test_model_setup.py
+++ b/test/test_modules/test_model_setup.py
@@ -2,6 +2,7 @@ import os
 
 import pytest
 
+from mlair.data_handling import DataPrepJoin
 from mlair.data_handling.data_generator import DataGenerator
 from mlair.helpers.datastore import EmptyScope
 from mlair.model_modules.keras_extensions import CallbackHandler
@@ -29,8 +30,9 @@ class TestModelSetup:
 
     @pytest.fixture
     def gen(self):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
-                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
+                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
+                             data_preparation=DataPrepJoin)
 
     @pytest.fixture
     def setup_with_gen(self, setup, gen):
diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py
index 93f322f10820348ea3b431baccfc7404079c8fe8..a35e810c2d62ab746004442bffee51d85dc17ab2 100644
--- a/test/test_modules/test_pre_processing.py
+++ b/test/test_modules/test_pre_processing.py
@@ -2,6 +2,7 @@ import logging
 
 import pytest
 
+from mlair.data_handling import DataPrepJoin
 from mlair.data_handling.data_generator import DataGenerator
 from mlair.helpers.datastore import NameNotFoundInScope
 from mlair.helpers import PyTestRegex
@@ -27,7 +28,8 @@ class TestPreProcessing:
     @pytest.fixture
     def obj_with_exp_setup(self):
         ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
-                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background")
+                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background",
+                        data_preparation=DataPrepJoin)
         pre = object.__new__(PreProcessing)
         super(PreProcessing, pre).__init__()
         yield pre
diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py
index 292bf7e76deb587b7a64e31e93a67cb143d8e540..1f218db5f46b256a4772bdd0521e248d84c54da2 100644
--- a/test/test_modules/test_training.py
+++ b/test/test_modules/test_training.py
@@ -9,6 +9,7 @@ import mock
 import pytest
 from keras.callbacks import History
 
+from mlair.data_handling import DataPrepJoin
 from mlair.data_handling.data_distributor import Distributor
 from mlair.data_handling.data_generator import DataGenerator
 from mlair.helpers import PyTestRegex
@@ -108,9 +109,9 @@ class TestTraining:
 
     @pytest.fixture
    def generator(self, path):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE',
-                             ['DEBW107'], ['o3', 'temp'], 'datetime', 'variables',
-                             'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), ['DEBW107'], ['o3', 'temp'], 'datetime',
+                             'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
+                             data_preparation=DataPrepJoin)
 
     @pytest.fixture
     def model(self):