diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 19a94fbb9dbbc8f382a225c852f34971a98395b8..db3b044c4909ce7acf28d5a2b633e835fbc58915 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -2,7 +2,7 @@ __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2019-11-07' import os -from typing import Union, List, Tuple, Any +from typing import Union, List, Tuple, Any, Dict import keras import xarray as xr @@ -11,6 +11,7 @@ import logging from src import helpers from src.data_handling.data_preparation import DataPrep +from src.join import EmptyQueryResult class DataGenerator(keras.utils.Sequence): @@ -25,7 +26,7 @@ class DataGenerator(keras.utils.Sequence): def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str], interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None, interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7, - window_lead_time: int = 4, transform_method: str = "standardise", **kwargs): + window_lead_time: int = 4, transformation: Dict = None, **kwargs): self.data_path = os.path.abspath(data_path) self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp") if not os.path.exists(self.data_path_tmp): @@ -41,8 +42,8 @@ class DataGenerator(keras.utils.Sequence): self.limit_nan_fill = limit_nan_fill self.window_history_size = window_history_size self.window_lead_time = window_lead_time - self.transform_method = transform_method self.kwargs = kwargs + self.transformation = self.setup_transformation(transformation) def __repr__(self): """ @@ -94,6 +95,44 @@ class DataGenerator(keras.utils.Sequence): data = self.get_data_generator(key=item) return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window") + def setup_transformation(self, transformation): + if transformation is None: + return + scope = transformation.get("scope", 
"station") + method = transformation.get("method", "standardise") + mean = transformation.get("mean", None) + std = transformation.get("std", None) + if scope == "data": + if mean == "accurate": + mean, std = self.calculate_accurate_transformation(method) + elif mean == "estimate": + mean, std = self.calculate_estimated_transformation(method) + else: + mean = mean + transformation["mean"] = mean + transformation["std"] = std + return transformation + + def calculate_accurate_transformation(self, method): + mean = None + std = None + return mean, std + + def calculate_estimated_transformation(self, method): + mean = xr.DataArray([[]]*len(self.variables),coords={"variables": self.variables, "Stations": range(0)}, dims=["variables", "Stations"]) + std = xr.DataArray([[]]*len(self.variables),coords={"variables": self.variables, "Stations": range(0)}, dims=["variables", "Stations"]) + for station in self.stations: + try: + data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type, + **self.kwargs) + data.transform("datetime", method=method) + mean = mean.combine_first(data.mean) + std = std.combine_first(data.std) + data.transform("datetime", method=method, inverse=True) + except EmptyQueryResult: + continue + return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None + def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep: """ Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
**helpers.dict_pop(self.transformation, "scope")) data.make_history_window(self.interpolate_dim, self.window_history_size) data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time) data.history_label_nan_remove(self.interpolate_dim) diff --git a/src/helpers.py b/src/helpers.py index c33684508b7b36cbebc3cbf3ac826d1779f9df50..e1496c3b232db29194878892647c59581b6a70a3 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -190,3 +190,8 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce """ multiplier = 10. ** decimals return round_type(number * multiplier) / multiplier + + +def dict_pop(dct: Dict, pop_keys): + pop_keys = to_list(pop_keys) + return {k: v for k, v in dct.items() if k not in pop_keys} diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 48f7c13e51622d7d52405b73c0a6f57537b5b476..9c208227741b23da2ce37de08f891ec61aefab34 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -33,7 +33,7 @@ class ExperimentSetup(RunEnvironment): limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", - create_new_model=None): + create_new_model=None, transformation=None): # create run framework super().__init__() @@ -77,6 +77,8 @@ class ExperimentSetup(RunEnvironment): self._set_param("window_history_size", window_history_size, default=13) self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="general.preprocessing") self._set_param("sampling", sampling) + self._set_param("transformation", transformation, default={"scope": "data", "method": "standardise", + "mean": "estimate"}) # target self._set_param("target_var", target_var, default="o3") diff --git
a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 4660a8116b6d0b860a7d0d50b92cee5e0deb77d8..62932edcd0f8a4b19efe96e173924942e5e41a2f 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -12,7 +12,7 @@ from src.run_modules.run_environment import RunEnvironment DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", - "station_type", "overwrite_local_data", "start", "end", "sampling"] + "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"] class PreProcessing(RunEnvironment): @@ -36,10 +36,15 @@ class PreProcessing(RunEnvironment): args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing") kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing") valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False) + self.calculate_transformation(args, kwargs, valid_stations, load_tmp=False) self.data_store.set("stations", valid_stations, "general") self.split_train_val_test() self.report_pre_processing() + def calculate_transformation(self, args: Dict, kwargs: Dict, all_stations: List[str], load_tmp): + + pass + def report_pre_processing(self): logging.debug(20 * '##') n_train = len(self.data_store.get('generator', 'general.train'))