From 47c9e0d30391a34eb5c097925deb0a3157282f4d Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Tue, 3 Mar 2020 10:18:51 +0100 Subject: [PATCH] do not calculate transformation if mean is given, switch to save local tmp data or not --- src/data_handling/data_generator.py | 36 +++++++++++-------- src/run_modules/experiment_setup.py | 1 + src/run_modules/pre_processing.py | 11 +++--- .../test_data_handling/test_data_generator.py | 2 +- 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 85c25030..897d85c1 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -99,19 +99,21 @@ class DataGenerator(keras.utils.Sequence): def setup_transformation(self, transformation): if transformation is None: return + transformation = transformation.copy() scope = transformation.get("scope", "station") method = transformation.get("method", "standardise") mean = transformation.get("mean", None) std = transformation.get("std", None) - if scope == "data": - if mean == "accurate": - mean, std = self.calculate_accurate_transformation(method) - elif mean == "estimate": - mean, std = self.calculate_estimated_transformation(method) - else: - mean = mean - transformation["mean"] = mean - transformation["std"] = std + if isinstance(mean, str): + if scope == "data": + if mean == "accurate": + mean, std = self.calculate_accurate_transformation(method) + elif mean == "estimate": + mean, std = self.calculate_estimated_transformation(method) + else: + mean = mean + transformation["mean"] = mean + transformation["std"] = std return transformation def calculate_accurate_transformation(self, method): @@ -150,18 +152,20 @@ class DataGenerator(keras.utils.Sequence): continue return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None - def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep: + def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True, + save_local_tmp_storage: bool = True) -> DataPrep: """ Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and remove nans. :param key: station key to choose the data generator. - :param local_tmp_storage: say if data should be processed from scratch or loaded as already processed data from - tmp pickle file to save computational time (but of course more disk space required). + :param load_local_tmp_storage: say if data should be processed from scratch or loaded as already processed data + from tmp pickle file to save computational time (but of course more disk space required). + :param save_local_tmp_storage: save processed data as temporal file locally (default True) :return: preprocessed data as a DataPrep instance """ station = self.get_station_key(key) try: - if not local_tmp_storage: + if not load_local_tmp_storage: raise FileNotFoundError data = self._load_pickle_data(station, self.variables) except FileNotFoundError: @@ -169,11 +173,13 @@ class DataGenerator(keras.utils.Sequence): data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type, **self.kwargs) data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill) - data.transform("datetime", **helpers.dict_pop(self.transformation, "scope")) + if self.transformation is not None: + data.transform("datetime", **helpers.dict_pop(self.transformation, "scope")) data.make_history_window(self.interpolate_dim, self.window_history_size) data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time) data.history_label_nan_remove(self.interpolate_dim) - self._save_pickle_data(data) + if save_local_tmp_storage: + self._save_pickle_data(data) return data def _save_pickle_data(self, data: Any): diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 9c208227..f039cd08 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -79,6 +79,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("sampling", sampling) self._set_param("transformation", transformation, default={"scope": "data", "method": "standardise", "mean": "estimate"}) + self._set_param("transformation", None, scope="general.preprocessing") # target self._set_param("target_var", target_var, default="o3") diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 62932edc..44de7117 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -35,7 +35,7 @@ class PreProcessing(RunEnvironment): def _run(self): args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing") kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing") - valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False) + valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False, save_tmp=False) self.calculate_transformation(args, kwargs, valid_stations, load_tmp=False) self.data_store.set("stations", valid_stations, "general") self.split_train_val_test() @@ -94,14 +94,16 @@ class PreProcessing(RunEnvironment): else: set_stations = stations[index_list] logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") - set_stations = self.check_valid_stations(args, kwargs, set_stations) + set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False) self.data_store.set("stations", set_stations, scope) set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) data_set = DataGenerator(**set_args, **kwargs) self.data_store.set("generator", data_set, scope) + if set_name == "train": + self.data_store.set("transformation", data_set.transformation, "general") @staticmethod - def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True): + def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True): """ Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode. @@ -123,7 +125,8 @@ class PreProcessing(RunEnvironment): t_inner.run() try: # (history, label) = data_gen[station] - data = data_gen.get_data_generator(key=station, local_tmp_storage=load_tmp) + data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp, + save_local_tmp_storage=save_tmp) valid_stations.append(station) logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}') logging.debug(f"{station}: loading time = {t_inner}") diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 142acd16..306dfa30 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -104,7 +104,7 @@ class TestDataGenerator: if os.path.exists(file): os.remove(file) assert not os.path.exists(file) - assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep) + assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrep) t = os.stat(file).st_ctime assert os.path.exists(file) assert isinstance(gen.get_data_generator("DEBW107"), DataPrep) -- GitLab