From 09c7a4188eb8cf3c5e1c0c562dc4f16d2f3bd957 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Wed, 1 Jul 2020 18:33:22 +0200 Subject: [PATCH] DataPreparation is now adjustable in all run scripts and workflows by using the data_preparation parameter. Default is DataPrepJoin. --- src/data_handling/data_generator.py | 26 +++++++------- src/data_handling/data_preparation_join.py | 4 ++- src/run.py | 3 +- src/run_modules/experiment_setup.py | 4 ++- src/run_modules/pre_processing.py | 4 +-- src/workflows/default_workflow.py | 6 ++-- test/test_data_handling/test_bootstraps.py | 7 ++-- .../test_data_distributor.py | 11 +++--- .../test_data_handling/test_data_generator.py | 36 ++++++++----------- .../test_data_preparation.py | 8 ++--- 10 files changed, 57 insertions(+), 52 deletions(-) diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 7b83b56f..8e14d019 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -13,7 +13,7 @@ import keras import xarray as xr from src import helpers -from src.data_handling.data_preparation_join import DataPrepJoin +from src.data_handling.data_preparation import AbstractDataPrep from src.helpers.join import EmptyQueryResult number = Union[float, int] @@ -57,15 +57,15 @@ class DataGenerator(keras.utils.Sequence): This class can also be used with keras' fit_generator and predict_generator. Individual stations are the iterables. """ - def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str], + def __init__(self, data_path: str, stations: Union[str, List[str]], variables: List[str], interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None, interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7, - window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs): + window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, + data_preparation=None, **kwargs): """ Set up data generator. :param data_path: path to data - :param network: the observational network, the data should come from :param stations: list with all stations to include :param variables: list with all used variables :param interpolate_dim: dimension along which interpolation is applied @@ -85,7 +85,6 @@ class DataGenerator(keras.utils.Sequence): self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp") if not os.path.exists(self.data_path_tmp): os.makedirs(self.data_path_tmp) - self.network = network self.stations = helpers.to_list(stations) self.variables = variables self.interpolate_dim = interpolate_dim @@ -97,12 +96,13 @@ class DataGenerator(keras.utils.Sequence): self.window_history_size = window_history_size self.window_lead_time = window_lead_time self.extreme_values = extreme_values + self.DataPrep = data_preparation if data_preparation is not None else AbstractDataPrep self.kwargs = kwargs self.transformation = self.setup_transformation(transformation) def __repr__(self): """Display all class attributes.""" - return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \ + return f"DataGenerator(path='{self.data_path}', stations={self.stations}, " \ f"variables={self.variables}, station_type={self.station_type}, " \ f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \ f"target_var='{self.target_var}', **{self.kwargs})" @@ -210,8 +210,8 @@ class DataGenerator(keras.utils.Sequence): std = None for station in self.stations: try: - data = DataPrepJoin(self.data_path, self.network, station, self.variables, station_type=self.station_type, - **self.kwargs) + data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type, + **self.kwargs) chunks = (1, 100, data.data.shape[2]) tmp.append(da.from_array(data.data.data, chunks=chunks)) except EmptyQueryResult: @@ -249,8 +249,8 @@ class DataGenerator(keras.utils.Sequence): std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"]) for station in self.stations: try: - data = DataPrepJoin(self.data_path, self.network, station, self.variables, station_type=self.station_type, - **self.kwargs) + data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type, + **self.kwargs) data.transform("datetime", method=method) mean = mean.combine_first(data.mean) std = std.combine_first(data.std) @@ -260,7 +260,7 @@ class DataGenerator(keras.utils.Sequence): return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True, - save_local_tmp_storage: bool = True) -> DataPrepJoin: + save_local_tmp_storage: bool = True) -> AbstractDataPrep: """ Create DataPrep object and preprocess data for given key. @@ -288,8 +288,8 @@ class DataGenerator(keras.utils.Sequence): data = self._load_pickle_data(station, self.variables) except FileNotFoundError: logging.debug(f"load not pickle data for {station}") - data = DataPrepJoin(self.data_path, self.network, station, self.variables, station_type=self.station_type, - **self.kwargs) + data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type, + **self.kwargs) if self.transformation is not None: data.transform("datetime", **helpers.remove_items(self.transformation, "scope")) data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill) diff --git a/src/data_handling/data_preparation_join.py b/src/data_handling/data_preparation_join.py index 7655fbf6..86c7dee0 100644 --- a/src/data_handling/data_preparation_join.py +++ b/src/data_handling/data_preparation_join.py @@ -53,7 +53,7 @@ class DataPrepJoin(AbstractDataPrep): """ - def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str], + def __init__(self, path: str, station: Union[str, List[str]], variables: List[str], network: str = None, station_type: str = None, **kwargs): self.network = network self.station_type = station_type @@ -80,6 +80,8 @@ class DataPrepJoin(AbstractDataPrep): if self.station_type is not None: check_dict = {"station_type": self.station_type, "network_name": self.network} for (k, v) in check_dict.items(): + if v is None: + continue if self.meta.at[k, self.station[0]] != v: logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new " diff --git a/src/run.py b/src/run.py index 1494be0a..7e262dd7 100644 --- a/src/run.py +++ b/src/run.py @@ -28,7 +28,8 @@ def run(stations=None, plot_list=None, model=None, batch_size=None, - epochs=None): + epochs=None, + data_preparation=None): params = inspect.getfullargspec(DefaultWorkflow).args kwargs = {k: v for k, v in locals().items() if k in params and v is not None} diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index a93fe403..b460decd 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -18,6 +18,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST +from src.data_handling import DataPrepJoin from src.run_modules.run_environment import RunEnvironment from src.model_modules.model_class import MyLittleModel as VanillaModel @@ -228,7 +229,7 @@ class ExperimentSetup(RunEnvironment): train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None, create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None, - batch_size=None, epochs=None): + batch_size=None, epochs=None, data_preparation=None): # create run framework super().__init__() @@ -296,6 +297,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("sampling", sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") + self._set_param("DataPrep", data_preparation, default=DataPrepJoin) # target self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR) diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index b4b36a20..6843ae2e 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -16,10 +16,10 @@ from src.configuration import path_config from src.helpers.join import EmptyQueryResult from src.run_modules.run_environment import RunEnvironment -DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] +DEFAULT_ARGS_LIST = ["data_path", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length", "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation", - "extreme_values", "extremes_on_right_tail_only"] + "extreme_values", "extremes_on_right_tail_only", "network", "DataPrep"] class PreProcessing(RunEnvironment): diff --git a/src/workflows/default_workflow.py b/src/workflows/default_workflow.py index 6a60c6ae..bbad7428 100644 --- a/src/workflows/default_workflow.py +++ b/src/workflows/default_workflow.py @@ -36,7 +36,8 @@ class DefaultWorkflow(Workflow): plot_list=None, model=None, batch_size=None, - epochs=None): + epochs=None, + data_preparation=None): super().__init__() # extract all given kwargs arguments @@ -80,7 +81,8 @@ class DefaultWorkflowHPC(Workflow): plot_list=None, model=None, batch_size=None, - epochs=None): + epochs=None, + data_preparation=None): super().__init__() # extract all given kwargs arguments diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py index 3d32a090..839b0220 100644 --- a/test/test_data_handling/test_bootstraps.py +++ b/test/test_data_handling/test_bootstraps.py @@ -9,13 +9,14 @@ import xarray as xr from src.data_handling.bootstraps import BootStraps, CreateShuffledData, BootStrapGenerator from src.data_handling.data_generator import DataGenerator +from src.data_handling import DataPrepJoin @pytest.fixture def orig_generator(data_path): - return DataGenerator(data_path, 'AIRBASE', ['DEBW107', 'DEBW013'], - ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014, - statistics_per_var={"o3": "dma8eu", "temp": "maximum"}) + return DataGenerator(data_path, ['DEBW107', 'DEBW013'], ['o3', 'temp'], 'datetime', 'variables', 'o3', + start=2010, end=2014, statistics_per_var={"o3": "dma8eu", "temp": "maximum"}, + data_preparation=DataPrepJoin) @pytest.fixture diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index 9e2242fe..43c61be2 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -7,6 +7,7 @@ import pytest from src.data_handling.data_distributor import Distributor from src.data_handling.data_generator import DataGenerator +from src.data_handling import DataPrepJoin from test.test_modules.test_training import my_test_model @@ -14,14 +15,16 @@ class TestDistributor: @pytest.fixture def generator(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], - 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], + 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, + data_preparation=DataPrepJoin) @pytest.fixture def generator_two_stations(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'], + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), ['DEBW107', 'DEBW013'], ['o3', 'temp'], 'datetime', 'variables', 'o3', - statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, + data_preparation=DataPrepJoin) @pytest.fixture def model(self): diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 754728ba..3144bde3 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -7,29 +7,24 @@ import pytest import xarray as xr from src.data_handling.data_generator import DataGenerator -from src.data_handling.data_preparation import DataPrep +from src.data_handling import DataPrepJoin from src.helpers.join import EmptyQueryResult class TestDataGenerator: - # @pytest.fixture(autouse=True, scope='module') - # def teardown_module(self): - # yield - # if "data" in os.listdir(os.path.dirname(__file__)): - # shutil.rmtree(os.path.join(os.path.dirname(__file__), "data"), ignore_errors=True) - @pytest.fixture def gen(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], - 'datetime', 'variables', 'o3', start=2010, end=2014) + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], + 'datetime', 'variables', 'o3', start=2010, end=2014, data_preparation=DataPrepJoin) @pytest.fixture def gen_with_transformation(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014, transformation={"scope": "data", "mean": "estimate"}, - statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, + data_preparation=DataPrepJoin) @pytest.fixture def gen_no_init(self): @@ -39,9 +34,9 @@ class TestDataGenerator: if not os.path.exists(path): os.makedirs(path) generator.stations = ["DEBW107", "DEBW013", "DEBW001"] - generator.network = "AIRBASE" generator.variables = ["temp", "o3"] generator.station_type = "background" + generator.DataPrep = DataPrepJoin generator.kwargs = {"start": 2010, "end": 2014, "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}} return generator @@ -50,8 +45,8 @@ class TestDataGenerator: tmp = np.nan for station in gen_no_init.stations: try: - data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, - station_type=gen_no_init.station_type, **gen_no_init.kwargs) + data_prep = DataPrepJoin(gen_no_init.data_path, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) tmp = data_prep.data.combine_first(tmp) except EmptyQueryResult: continue @@ -64,8 +59,8 @@ class TestDataGenerator: mean, std = None, None for station in gen_no_init.stations: try: - data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables, - station_type=gen_no_init.station_type, **gen_no_init.kwargs) + data_prep = DataPrepJoin(gen_no_init.data_path, station, gen_no_init.variables, + station_type=gen_no_init.station_type, **gen_no_init.kwargs) mean = data_prep.data.mean(axis=1).combine_first(mean) std = data_prep.data.std(axis=1).combine_first(std) except EmptyQueryResult: @@ -82,7 +77,6 @@ class TestDataGenerator: def test_init(self, gen): assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data') - assert gen.network == 'AIRBASE' assert gen.stations == ['DEBW107'] assert gen.variables == ['o3', 'temp'] assert gen.station_type is None @@ -98,7 +92,7 @@ class TestDataGenerator: def test_repr(self, gen): path = os.path.join(os.path.dirname(__file__), 'data') - assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='AIRBASE', stations=['DEBW107'], " \ + assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', stations=['DEBW107'], " \ f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \ f"target_dim='variables', target_var='o3', **{{'start': 2010, 'end': 2014}})" \ .rstrip() @@ -222,13 +216,13 @@ class TestDataGenerator: if os.path.exists(file): os.remove(file) assert not os.path.exists(file) - assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrep) + assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrepJoin) t = os.stat(file).st_ctime assert os.path.exists(file) - assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) + assert isinstance(gen.get_data_generator("DEBW107"), DataPrepJoin) assert os.stat(file).st_mtime == t os.remove(file) - assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) + assert isinstance(gen.get_data_generator("DEBW107"), DataPrepJoin) assert os.stat(file).st_ctime > t def test_get_data_generator_transform(self, gen_with_transformation): diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 00efa1ac..3af8a045 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -28,8 +28,8 @@ class TestAbstractDataPrep: @pytest.fixture def data(self): - return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], - statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}).data + return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, network="AIRBASE").data @pytest.fixture def data_prep(self, data_prep_no_init, data): @@ -421,8 +421,8 @@ class TestDataPrepJoin: @pytest.fixture def data(self): - return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], - station_type='background', test='testKWARGS', + return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'], + station_type='background', test='testKWARGS', network="AIRBASE", statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) @pytest.fixture -- GitLab