diff --git a/src/data_handling/__init__.py b/src/data_handling/__init__.py index dc14146417bd4a6bf30ccbb79537e7920313b077..1bd380d35cae73ba7dee2c2a10214483ab0ed62d 100644 --- a/src/data_handling/__init__.py +++ b/src/data_handling/__init__.py @@ -14,5 +14,5 @@ from .data_preparation_join import DataPrepJoin from .data_generator import DataGenerator from .data_distributor import Distributor from .iterator import KerasIterator, DataCollection -from .advanced_data_handling import DataPreparation +from .advanced_data_handling import DefaultDataPreparation from .data_preparation import StationPrep \ No newline at end of file diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py index d4c4e363dfd1d27ad7ecb2ef34a619b61c20e9fb..e5c5214de9fe03f9e6c5de6e2ddcfbfb9987d052 100644 --- a/src/data_handling/advanced_data_handling.py +++ b/src/data_handling/advanced_data_handling.py @@ -4,7 +4,6 @@ __date__ = '2020-07-08' from src.helpers import to_list, remove_items -from src.data_handling.data_preparation import StationPrep import numpy as np import xarray as xr import pickle @@ -12,10 +11,13 @@ import os import pandas as pd import datetime as dt import shutil +import inspect from typing import Union, List, Tuple import logging from functools import reduce +from src.data_handling.data_preparation import StationPrep + number = Union[float, int] num_or_list = Union[number, List[number]] @@ -45,25 +47,68 @@ class DummyDataSingleStation: # pragma: no cover return self.name -class DataPreparation: +class AbstractDataPreparation: + + _requirements = [] + + def __init__(self, *args, **kwargs): + pass - def __init__(self, id_class, interpolate_dim: str, data_path, neighbors=None, min_length=0, - extreme_values: num_or_list = None,extremes_on_right_tail_only: bool = False): + @classmethod + def build(cls, *args, **kwargs): + """Return initialised class.""" + return cls(*args, **kwargs) + + @classmethod + def requirements(cls): + """Return requirements and own arguments without duplicates.""" + return list(set(cls._requirements + cls.own_args())) + + @classmethod + def own_args(cls, *args): + return remove_items(inspect.getfullargspec(cls).args, ["self"] + list(args)) + + def get_X(self, upsampling=False, as_numpy=False): + raise NotImplementedError + + def get_Y(self, upsampling=False, as_numpy=False): + raise NotImplementedError + + +class DefaultDataPreparation(AbstractDataPreparation): + + _requirements = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"]) + + def __init__(self, id_class, data_path, min_length=0, + extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False): + super().__init__() self.id_class = id_class - self.neighbors = to_list(neighbors) if neighbors is not None else [] - self.interpolate_dim = interpolate_dim + self.interpolate_dim = "datetime" self.min_length = min_length self._X = None self._Y = None self._X_extreme = None self._Y_extreme = None self._save_file = os.path.join(data_path, f"data_preparation_{str(self.id_class)}.pickle") - self._collection = [] - self._create_collection() + self._collection = self._create_collection() self.harmonise_X() self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolate_dim) self._store(fresh_store=True) + @classmethod + def build(cls, station, **kwargs): + sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs} + sp = StationPrep(station, **sp_keys) + dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs} + return cls(sp, **dp_args) + + def _create_collection(self): + return [self.id_class] + + @classmethod + def requirements(cls): + return remove_items(super().requirements(), "id_class") + def _reset_data(self): self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None @@ -99,10 +144,6 @@ class DataPreparation: self._reset_data() return X, Y - def _create_collection(self): - for data_class in [self.id_class] + self.neighbors: - self._collection.append(data_class) - def __repr__(self): return ";".join(list(map(lambda x: str(x), self._collection))) @@ -221,7 +262,7 @@ def run_data_prep(): data.get_Y() path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") - data_prep = DataPreparation(DummyDataSingleStation("main_class"), "datetime", path, + data_prep = DataPreparation(DummyDataSingleStation("main_class"), path, neighbors=[DummyDataSingleStation("neighbor1"), DummyDataSingleStation("neighbor2")], extreme_values=[1., 1.2]) data_prep.get_data(upsampling=False) @@ -238,54 +279,20 @@ def create_data_prep(): interpolate_dim = 'datetime' window_history_size = 7 window_lead_time = 3 - central_station = StationPrep(path, "DEBW011", {'o3': 'dma8eu', 'temp': 'maximum'}, {},station_type, network, sampling, target_dim, + central_station = StationPrep("DEBW011", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {},station_type, network, sampling, target_dim, target_var, interpolate_dim, window_history_size, window_lead_time) - neighbor1 = StationPrep(path, "DEBW013", {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {},station_type, network, sampling, target_dim, + neighbor1 = StationPrep("DEBW013", path, {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, {},station_type, network, sampling, target_dim, target_var, interpolate_dim, window_history_size, window_lead_time) - neighbor2 = StationPrep(path, "DEBW034", {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim, + neighbor2 = StationPrep("DEBW034", path, {'o3': 'dma8eu', 'temp': 'maximum'}, {}, station_type, network, sampling, target_dim, target_var, interpolate_dim, window_history_size, window_lead_time) data_prep = [] - data_prep.append(DataPreparation(central_station, interpolate_dim, path, neighbors=[neighbor1, neighbor2])) - data_prep.append(DataPreparation(neighbor1, interpolate_dim, path, neighbors=[central_station, neighbor2])) - data_prep.append(DataPreparation(neighbor2, interpolate_dim, path, neighbors=[neighbor1, central_station])) + data_prep.append(DataPreparation(central_station, path, neighbors=[neighbor1, neighbor2])) + data_prep.append(DataPreparation(neighbor1, path, neighbors=[central_station, neighbor2])) + data_prep.append(DataPreparation(neighbor2, path, neighbors=[neighbor1, central_station])) return data_prep -class AbstractDataClass: - - def __init__(self): - self._requires = [] - - def __call__(self, *args, **kwargs): - raise NotImplementedError - - @property - def requirements(self): - return self._requires - - @requirements.setter - def requirements(self, value): - self._requires = value - - -class CustomDataClass(AbstractDataClass): - - def __init__(self): - import inspect - super().__init__() - self.sp_keys = remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"]) - self.dp_keys = remove_items(inspect.getfullargspec(DataPreparation).args, ["self", "id_class"]) - self.requirements = self.sp_keys + self.dp_keys - - def __call__(self, station, **kwargs): - sp_keys = {k: kwargs[k] for k in self.sp_keys if k in kwargs} - sp_keys["station"] = station - sp = StationPrep(**sp_keys) - dp_args = {k: kwargs[k] for k in self.dp_keys if k in kwargs} - return DataPreparation(sp, **dp_args) - - if __name__ == "__main__": from src.data_handling.data_preparation import StationPrep from src.data_handling.iterator import KerasIterator, DataCollection diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index dadda2c58979ddb2678d366470c3b6d3f0584ee4..09c16c68196b09fc7c1fbe5ef4b2639b684205a4 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -68,7 +68,7 @@ class AbstractStationPrep(): class StationPrep(AbstractStationPrep): - def __init__(self, data_path, station, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var, + def __init__(self, station, data_path, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var, interpolate_dim, window_history_size, window_lead_time, overwrite_local_data: bool = False, **kwargs): super().__init__() # path, station, statistics_per_var, transformation, **kwargs) self.station_type = station_type @@ -93,7 +93,7 @@ class StationPrep(AbstractStationPrep): self.label = None self.observation = None - self.transformation = self.setup_transformation(transformation) + self.transformation = None # self.setup_transformation(transformation) self.kwargs = kwargs self.kwargs["overwrite_local_data"] = overwrite_local_data diff --git a/src/data_handling/data_preparation_neighbors.py b/src/data_handling/data_preparation_neighbors.py new file mode 100644 index 0000000000000000000000000000000000000000..f5b5c3c436ef057244544248e6c7deedafbf0c4b --- /dev/null +++ b/src/data_handling/data_preparation_neighbors.py @@ -0,0 +1,67 @@ + +__author__ = 'Lukas Leufen' +__date__ = '2020-07-17' + + +from src.helpers import to_list, remove_items +from src.data_handling.data_preparation import StationPrep +from src.data_handling.advanced_data_handling import AbstractDataPreparation, DefaultDataPreparation +import numpy as np +import xarray as xr +import pickle +import os +import shutil +import inspect + +from typing import Union, List, Tuple +import logging +from functools import reduce + +number = Union[float, int] +num_or_list = Union[number, List[number]] + + +class DataPreparationNeighbors(DefaultDataPreparation): + + def __init__(self, id_class, data_path, neighbors=None, min_length=0, + extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False): + self.neighbors = to_list(neighbors) if neighbors is not None else [] + super().__init__(id_class, data_path, min_length=min_length, extreme_values=extreme_values, + extremes_on_right_tail_only=extremes_on_right_tail_only) + + @classmethod + def build(cls, station, **kwargs): + sp_keys = {k: kwargs[k] for k in cls._requirements if k in kwargs} + sp = StationPrep(station, **sp_keys) + n_list = [] + for neighbor in kwargs.get("neighbors", []): + n_list.append(StationPrep(neighbor, **sp_keys)) + else: + kwargs["neighbors"] = n_list if len(n_list) > 0 else None + dp_args = {k: kwargs[k] for k in cls.own_args("id_class") if k in kwargs} + return cls(sp, **dp_args) + + def _create_collection(self): + return [self.id_class] + self.neighbors + + +if __name__ == "__main__": + + a = DataPreparationNeighbors + requirements = a.requirements() + + kwargs = {"path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"), + "station_type": None, + "network": 'UBA', + "sampling": 'daily', + "target_dim": 'variables', + "target_var": 'o3', + "interpolate_dim": 'datetime', + "window_history_size": 7, + "window_lead_time": 3, + "neighbors": ["DEBW034"], + "data_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata"), + "statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}, + "transformation": None,} + a_inst = a.build("DEBW011", **kwargs) + print(a_inst) diff --git a/src/run.py b/src/run.py index 31127cfe81f0abcea753c9987c78cbeb29a696b4..8a4ade33c0e5b260fafab58e76cf753455077d50 100644 --- a/src/run.py +++ b/src/run.py @@ -29,7 +29,7 @@ def run(stations=None, model=None, batch_size=None, epochs=None, - data_preparation=None): + data_preparation=None,): params = inspect.getfullargspec(DefaultWorkflow).args kwargs = {k: v for k, v in locals().items() if k in params and v is not None} @@ -39,9 +39,4 @@ def run(stations=None, if __name__ == "__main__": - from src.data_handling.advanced_data_handling import CustomDataClass - run(data_preparation=CustomDataClass, statistics_per_var={'o3': 'dma8eu'}, transformation={"scope": "data", - "method": "standardise", - "mean": 50, - "std": 50}, - trainable=False, create_new_model=False) + run(stations=["DEBW013","DEBW025"], statistics_per_var={'o3': 'dma8eu', "temp": "maximum"}, trainable=True, create_new_model=True) diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 0a9d711934acd9ff0901085e3e35b43a70c4aca8..15b5c4c6e9d01284d108284365546f1eac9804c1 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -18,8 +18,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST -from src.data_handling import DataPrepJoin -from src.data_handling.advanced_data_handling import CustomDataClass +from src.data_handling.advanced_data_handling import DefaultDataPreparation from src.run_modules.run_environment import RunEnvironment from src.model_modules.model_class import MyLittleModel as VanillaModel @@ -301,9 +300,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("sampling", sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") - self._set_param("data_preparation", data_preparation() if data_preparation is not None else None, - default=CustomDataClass()) - assert isinstance(getattr(self.data_store.get("data_preparation"), "requirements"), property) is False + self._set_param("data_preparation", data_preparation, default=DefaultDataPreparation) # target self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR) @@ -350,6 +347,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS, scope="general.postprocessing") self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") + self._set_param("neighbors", ["DEBW030"]) # TODO: just for testing # check variables, statistics and target variable self._check_target_var() diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index c5f10e8b9eac0660753c2af858d74d5def31ced9..c6ea67b87fc33a0952a5123754ab3fea62eee488 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -11,8 +11,7 @@ import numpy as np import pandas as pd from src.data_handling import DataGenerator -from src.data_handling import DataCollection, DataPreparation, StationPrep -from src.data_handling.advanced_data_handling import CustomDataClass +from src.data_handling import DataCollection from src.helpers import TimeTracking from src.configuration import path_config from src.helpers.join import EmptyQueryResult @@ -260,10 +259,10 @@ class PreProcessing(RunEnvironment): logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else 'all'}") collection = DataCollection() valid_stations = [] - kwargs = self.data_store.create_args_dict(data_preparation.requirements, scope=set_name) + kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name) for station in set_stations: try: - dp = data_preparation(station, **kwargs) + dp = data_preparation.build(station, **kwargs) collection.add(dp) valid_stations.append(station) except (AttributeError, EmptyQueryResult): diff --git a/src/workflows/abstract_workflow.py b/src/workflows/abstract_workflow.py index 5d4e62c8a2e409e865f43412a6757a9cb4e4b1f3..350008eace4598567779228b1302a83c7375fd06 100644 --- a/src/workflows/abstract_workflow.py +++ b/src/workflows/abstract_workflow.py @@ -26,4 +26,4 @@ class Workflow: """Run workflow embedded in a run environment and according to the stage's ordering.""" with RunEnvironment(): for stage, kwargs in self._registry.items(): - stage(**kwargs) \ No newline at end of file + stage(**kwargs)