From 09c7a4188eb8cf3c5e1c0c562dc4f16d2f3bd957 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 1 Jul 2020 18:33:22 +0200
Subject: [PATCH] The data preparation class is now adjustable in all run scripts
 and workflows via the data_preparation parameter; the default is DataPrepJoin.

---
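A minimal run-script sketch of the new parameter (the station IDs are
illustrative values taken from the test fixtures, and it is assumed that
src.run is importable from the project root). Leaving data_preparation
unset keeps the previous behaviour, because ExperimentSetup falls back to
DataPrepJoin:

    from src.data_handling import DataPrepJoin
    from src.run import run

    # Select the data preparation class explicitly; omitting the
    # argument defaults to DataPrepJoin.
    run(stations=["DEBW107", "DEBW013"], data_preparation=DataPrepJoin)
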
 src/data_handling/data_generator.py           | 26 +++++++-------
 src/data_handling/data_preparation_join.py    |  4 ++-
 src/run.py                                    |  3 +-
 src/run_modules/experiment_setup.py           |  4 ++-
 src/run_modules/pre_processing.py             |  4 +--
 src/workflows/default_workflow.py             |  6 ++--
 test/test_data_handling/test_bootstraps.py    |  7 ++--
 .../test_data_distributor.py                  | 11 +++---
 .../test_data_handling/test_data_generator.py | 36 ++++++++-----------
 .../test_data_preparation.py                  |  8 ++---
 10 files changed, 57 insertions(+), 52 deletions(-)
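
At the DataGenerator level the class can be passed directly as well; the
sketch below mirrors the updated test fixtures (the data path and the
statistics_per_var values are illustrative):

    from src.data_handling import DataPrepJoin
    from src.data_handling.data_generator import DataGenerator

    generator = DataGenerator("path/to/data", "DEBW107", ["o3", "temp"],
                              "datetime", "variables", "o3",
                              statistics_per_var={"o3": "dma8eu", "temp": "maximum"},
                              data_preparation=DataPrepJoin)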

diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 7b83b56f..8e14d019 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -13,7 +13,7 @@ import keras
 import xarray as xr
 
 from src import helpers
-from src.data_handling.data_preparation_join import DataPrepJoin
+from src.data_handling.data_preparation import AbstractDataPrep
 from src.helpers.join import EmptyQueryResult
 
 number = Union[float, int]
@@ -57,15 +57,15 @@ class DataGenerator(keras.utils.Sequence):
     This class can also be used with keras' fit_generator and predict_generator. Individual stations are the iterables.
     """
 
-    def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
+    def __init__(self, data_path: str, stations: Union[str, List[str]], variables: List[str],
                  interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
-                 window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs):
+                 window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None,
+                 data_preparation=None, **kwargs):
         """
         Set up data generator.
 
         :param data_path: path to data
-        :param network: the observational network, the data should come from
         :param stations: list with all stations to include
         :param variables: list with all used variables
         :param interpolate_dim: dimension along which interpolation is applied
@@ -85,7 +85,6 @@ class DataGenerator(keras.utils.Sequence):
         self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
         if not os.path.exists(self.data_path_tmp):
             os.makedirs(self.data_path_tmp)
-        self.network = network
         self.stations = helpers.to_list(stations)
         self.variables = variables
         self.interpolate_dim = interpolate_dim
@@ -97,12 +96,13 @@ class DataGenerator(keras.utils.Sequence):
         self.window_history_size = window_history_size
         self.window_lead_time = window_lead_time
         self.extreme_values = extreme_values
+        self.DataPrep = data_preparation if data_preparation is not None else AbstractDataPrep
         self.kwargs = kwargs
         self.transformation = self.setup_transformation(transformation)
 
     def __repr__(self):
         """Display all class attributes."""
-        return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \
+        return f"DataGenerator(path='{self.data_path}', stations={self.stations}, " \
                f"variables={self.variables}, station_type={self.station_type}, " \
                f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \
                f"target_var='{self.target_var}', **{self.kwargs})"
@@ -210,8 +210,8 @@ class DataGenerator(keras.utils.Sequence):
         std = None
         for station in self.stations:
             try:
-                data = DataPrepJoin(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                                    **self.kwargs)
+                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
+                                     **self.kwargs)
                 chunks = (1, 100, data.data.shape[2])
                 tmp.append(da.from_array(data.data.data, chunks=chunks))
             except EmptyQueryResult:
@@ -249,8 +249,8 @@ class DataGenerator(keras.utils.Sequence):
         std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"])
         for station in self.stations:
             try:
-                data = DataPrepJoin(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                                    **self.kwargs)
+                data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
+                                     **self.kwargs)
                 data.transform("datetime", method=method)
                 mean = mean.combine_first(data.mean)
                 std = std.combine_first(data.std)
@@ -260,7 +260,7 @@ class DataGenerator(keras.utils.Sequence):
         return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None
 
     def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True,
-                           save_local_tmp_storage: bool = True) -> DataPrepJoin:
+                           save_local_tmp_storage: bool = True) -> AbstractDataPrep:
         """
         Create DataPrep object and preprocess data for given key.
 
@@ -288,8 +288,8 @@ class DataGenerator(keras.utils.Sequence):
             data = self._load_pickle_data(station, self.variables)
         except FileNotFoundError:
             logging.debug(f"load not pickle data for {station}")
-            data = DataPrepJoin(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                                **self.kwargs)
+            data = self.DataPrep(self.data_path, station, self.variables, station_type=self.station_type,
+                                 **self.kwargs)
             if self.transformation is not None:
                 data.transform("datetime", **helpers.remove_items(self.transformation, "scope"))
             data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
diff --git a/src/data_handling/data_preparation_join.py b/src/data_handling/data_preparation_join.py
index 7655fbf6..86c7dee0 100644
--- a/src/data_handling/data_preparation_join.py
+++ b/src/data_handling/data_preparation_join.py
@@ -53,7 +53,7 @@ class DataPrepJoin(AbstractDataPrep):
 
     """
 
-    def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str],
+    def __init__(self, path: str, station: Union[str, List[str]], variables: List[str], network: str = None,
                  station_type: str = None, **kwargs):
         self.network = network
         self.station_type = station_type
@@ -80,6 +80,8 @@ class DataPrepJoin(AbstractDataPrep):
         if self.station_type is not None:
             check_dict = {"station_type": self.station_type, "network_name": self.network}
             for (k, v) in check_dict.items():
+                if v is None:
+                    continue
                 if self.meta.at[k, self.station[0]] != v:
                     logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
                                   f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
diff --git a/src/run.py b/src/run.py
index 1494be0a..7e262dd7 100644
--- a/src/run.py
+++ b/src/run.py
@@ -28,7 +28,8 @@ def run(stations=None,
         plot_list=None,
         model=None,
         batch_size=None,
-        epochs=None):
+        epochs=None,
+        data_preparation=None):
 
     params = inspect.getfullargspec(DefaultWorkflow).args
     kwargs = {k: v for k, v in locals().items() if k in params and v is not None}
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index a93fe403..b460decd 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -18,6 +18,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D
     DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
+from src.data_handling import DataPrepJoin
 from src.run_modules.run_environment import RunEnvironment
 from src.model_modules.model_class import MyLittleModel as VanillaModel
 
@@ -228,7 +229,7 @@ class ExperimentSetup(RunEnvironment):
                  train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None,
                  create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None,
-                 batch_size=None, epochs=None):
+                 batch_size=None, epochs=None, data_preparation=None):
 
         # create run framework
         super().__init__()
@@ -296,6 +297,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("sampling", sampling)
         self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
         self._set_param("transformation", None, scope="preprocessing")
+        self._set_param("DataPrep", data_preparation, default=DataPrepJoin)
 
         # target
         self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index b4b36a20..6843ae2e 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -16,10 +16,10 @@ from src.configuration import path_config
 from src.helpers.join import EmptyQueryResult
 from src.run_modules.run_environment import RunEnvironment
 
-DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
+DEFAULT_ARGS_LIST = ["data_path", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
 DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length",
                        "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation",
-                       "extreme_values", "extremes_on_right_tail_only"]
+                       "extreme_values", "extremes_on_right_tail_only", "network", "DataPrep"]
 
 
 class PreProcessing(RunEnvironment):
diff --git a/src/workflows/default_workflow.py b/src/workflows/default_workflow.py
index 6a60c6ae..bbad7428 100644
--- a/src/workflows/default_workflow.py
+++ b/src/workflows/default_workflow.py
@@ -36,7 +36,8 @@ class DefaultWorkflow(Workflow):
         plot_list=None,
         model=None,
         batch_size=None,
-        epochs=None):
+        epochs=None,
+        data_preparation=None):
         super().__init__()
 
         # extract all given kwargs arguments
@@ -80,7 +81,8 @@ class DefaultWorkflowHPC(Workflow):
         plot_list=None,
         model=None,
         batch_size=None,
-        epochs=None):
+        epochs=None,
+        data_preparation=None):
         super().__init__()
 
         # extract all given kwargs arguments
diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py
index 3d32a090..839b0220 100644
--- a/test/test_data_handling/test_bootstraps.py
+++ b/test/test_data_handling/test_bootstraps.py
@@ -9,13 +9,14 @@ import xarray as xr
 
 from src.data_handling.bootstraps import BootStraps, CreateShuffledData, BootStrapGenerator
 from src.data_handling.data_generator import DataGenerator
+from src.data_handling import DataPrepJoin
 
 
 @pytest.fixture
 def orig_generator(data_path):
-    return DataGenerator(data_path, 'AIRBASE', ['DEBW107', 'DEBW013'],
-                         ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014,
-                         statistics_per_var={"o3": "dma8eu", "temp": "maximum"})
+    return DataGenerator(data_path, ['DEBW107', 'DEBW013'], ['o3', 'temp'], 'datetime', 'variables', 'o3',
+                         start=2010, end=2014, statistics_per_var={"o3": "dma8eu", "temp": "maximum"},
+                         data_preparation=DataPrepJoin)
 
 
 @pytest.fixture
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index 9e2242fe..43c61be2 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -7,6 +7,7 @@ import pytest
 
 from src.data_handling.data_distributor import Distributor
 from src.data_handling.data_generator import DataGenerator
+from src.data_handling import DataPrepJoin
 from test.test_modules.test_training import my_test_model
 
 
@@ -14,14 +15,16 @@ class TestDistributor:
 
     @pytest.fixture
     def generator(self):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
-                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
+                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
+                             data_preparation=DataPrepJoin)
 
     @pytest.fixture
     def generator_two_stations(self):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'],
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), ['DEBW107', 'DEBW013'],
                              ['o3', 'temp'], 'datetime', 'variables', 'o3',
-                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
+                             data_preparation=DataPrepJoin)
 
     @pytest.fixture
     def model(self):
diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py
index 754728ba..3144bde3 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -7,29 +7,24 @@ import pytest
 import xarray as xr
 
 from src.data_handling.data_generator import DataGenerator
-from src.data_handling.data_preparation import DataPrep
+from src.data_handling import DataPrepJoin
 from src.helpers.join import EmptyQueryResult
 
 
 class TestDataGenerator:
 
-    # @pytest.fixture(autouse=True, scope='module')
-    # def teardown_module(self):
-    #     yield
-    #     if "data" in os.listdir(os.path.dirname(__file__)):
-    #         shutil.rmtree(os.path.join(os.path.dirname(__file__), "data"), ignore_errors=True)
-
     @pytest.fixture
     def gen(self):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
-                             'datetime', 'variables', 'o3', start=2010, end=2014)
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
+                             'datetime', 'variables', 'o3', start=2010, end=2014, data_preparation=DataPrepJoin)
 
     @pytest.fixture
     def gen_with_transformation(self):
-        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
+        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
                              'datetime', 'variables', 'o3', start=2010, end=2014,
                              transformation={"scope": "data", "mean": "estimate"},
-                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
+                             data_preparation=DataPrepJoin)
 
     @pytest.fixture
     def gen_no_init(self):
@@ -39,9 +34,9 @@ class TestDataGenerator:
         if not os.path.exists(path):
             os.makedirs(path)
         generator.stations = ["DEBW107", "DEBW013", "DEBW001"]
-        generator.network = "AIRBASE"
         generator.variables = ["temp", "o3"]
         generator.station_type = "background"
+        generator.DataPrep = DataPrepJoin
         generator.kwargs = {"start": 2010, "end": 2014, "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}}
         return generator
 
@@ -50,8 +45,8 @@ class TestDataGenerator:
         tmp = np.nan
         for station in gen_no_init.stations:
             try:
-                data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables,
-                                     station_type=gen_no_init.station_type, **gen_no_init.kwargs)
+                data_prep = DataPrepJoin(gen_no_init.data_path, station, gen_no_init.variables,
+                                         station_type=gen_no_init.station_type, **gen_no_init.kwargs)
                 tmp = data_prep.data.combine_first(tmp)
             except EmptyQueryResult:
                 continue
@@ -64,8 +59,8 @@ class TestDataGenerator:
         mean, std = None, None
         for station in gen_no_init.stations:
             try:
-                data_prep = DataPrep(gen_no_init.data_path, gen_no_init.network, station, gen_no_init.variables,
-                                     station_type=gen_no_init.station_type, **gen_no_init.kwargs)
+                data_prep = DataPrepJoin(gen_no_init.data_path, station, gen_no_init.variables,
+                                         station_type=gen_no_init.station_type, **gen_no_init.kwargs)
                 mean = data_prep.data.mean(axis=1).combine_first(mean)
                 std = data_prep.data.std(axis=1).combine_first(std)
             except EmptyQueryResult:
@@ -82,7 +77,6 @@ class TestDataGenerator:
 
     def test_init(self, gen):
         assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data')
-        assert gen.network == 'AIRBASE'
         assert gen.stations == ['DEBW107']
         assert gen.variables == ['o3', 'temp']
         assert gen.station_type is None
@@ -98,7 +92,7 @@ class TestDataGenerator:
 
     def test_repr(self, gen):
         path = os.path.join(os.path.dirname(__file__), 'data')
-        assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='AIRBASE', stations=['DEBW107'], " \
+        assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', stations=['DEBW107'], " \
                                           f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \
                                           f"target_dim='variables', target_var='o3', **{{'start': 2010, 'end': 2014}})" \
             .rstrip()
@@ -222,13 +216,13 @@ class TestDataGenerator:
         if os.path.exists(file):
             os.remove(file)
         assert not os.path.exists(file)
-        assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrep)
+        assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrepJoin)
         t = os.stat(file).st_ctime
         assert os.path.exists(file)
-        assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
+        assert isinstance(gen.get_data_generator("DEBW107"), DataPrepJoin)
         assert os.stat(file).st_mtime == t
         os.remove(file)
-        assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
+        assert isinstance(gen.get_data_generator("DEBW107"), DataPrepJoin)
         assert os.stat(file).st_ctime > t
 
     def test_get_data_generator_transform(self, gen_with_transformation):
diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py
index 00efa1ac..3af8a045 100644
--- a/test/test_data_handling/test_data_preparation.py
+++ b/test/test_data_handling/test_data_preparation.py
@@ -28,8 +28,8 @@ class TestAbstractDataPrep:
 
     @pytest.fixture
     def data(self):
-        return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
-                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}).data
+        return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
+                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, network="AIRBASE").data
 
     @pytest.fixture
     def data_prep(self, data_prep_no_init, data):
@@ -421,8 +421,8 @@ class TestDataPrepJoin:
 
     @pytest.fixture
     def data(self):
-        return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
-                        station_type='background', test='testKWARGS',
+        return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'DEBW107', ['o3', 'temp'],
+                        station_type='background', test='testKWARGS', network="AIRBASE",
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
 
     @pytest.fixture
-- 
GitLab