diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 85c25030a5f15ba2372b2a242acca3496dfebc4a..897d85c121fbe4b64b81bfc504209c7a3069f9f6 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -99,19 +99,21 @@ class DataGenerator(keras.utils.Sequence):
     def setup_transformation(self, transformation):
         if transformation is None:
             return
+        transformation = transformation.copy()
         scope = transformation.get("scope", "station")
         method = transformation.get("method", "standardise")
         mean = transformation.get("mean", None)
         std = transformation.get("std", None)
-        if scope == "data":
-            if mean == "accurate":
-                mean, std = self.calculate_accurate_transformation(method)
-            elif mean == "estimate":
-                mean, std = self.calculate_estimated_transformation(method)
-            else:
-                mean = mean
-        transformation["mean"] = mean
-        transformation["std"] = std
+        if isinstance(mean, str):
+            if scope == "data":
+                if mean == "accurate":
+                    mean, std = self.calculate_accurate_transformation(method)
+                elif mean == "estimate":
+                    mean, std = self.calculate_estimated_transformation(method)
+                else:
+                    mean = mean
+            transformation["mean"] = mean
+            transformation["std"] = std
         return transformation

     def calculate_accurate_transformation(self, method):
@@ -150,18 +152,20 @@ class DataGenerator(keras.utils.Sequence):
                 continue
         return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None

-    def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
+    def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True,
+                           save_local_tmp_storage: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
         :param key: station key to choose the data generator.
-        :param local_tmp_storage: say if data should be processed from scratch or loaded as already processed data from
-            tmp pickle file to save computational time (but of course more disk space required).
+        :param load_local_tmp_storage: specify whether data should be processed from scratch or loaded as already
+            processed data from a tmp pickle file to save computational time (but of course more disk space required).
+        :param save_local_tmp_storage: save processed data as a temporary file locally (default True)
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
         try:
-            if not local_tmp_storage:
+            if not load_local_tmp_storage:
                 raise FileNotFoundError
             data = self._load_pickle_data(station, self.variables)
         except FileNotFoundError:
@@ -169,11 +173,13 @@
             data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                             **self.kwargs)
             data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
-            data.transform("datetime", **helpers.dict_pop(self.transformation, "scope"))
+            if self.transformation is not None:
+                data.transform("datetime", **helpers.dict_pop(self.transformation, "scope"))
             data.make_history_window(self.interpolate_dim, self.window_history_size)
             data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
             data.history_label_nan_remove(self.interpolate_dim)
-            self._save_pickle_data(data)
+            if save_local_tmp_storage:
+                self._save_pickle_data(data)
         return data

     def _save_pickle_data(self, data: Any):
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 9c208227741b23da2ce37de08f891ec61aefab34..f039cd08fd46d926704dc216678b67c0f8878006 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -79,6 +79,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("sampling", sampling)
         self._set_param("transformation", transformation, default={"scope": "data", "method": "standardise",
                                                                    "mean": "estimate"})
+        self._set_param("transformation", None, scope="general.preprocessing")

         # target
         self._set_param("target_var", target_var, default="o3")
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 62932edcd0f8a4b19efe96e173924942e5e41a2f..44de71171377076e887c099ec1229391daae32d8 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -35,7 +35,7 @@ class PreProcessing(RunEnvironment):
     def _run(self):
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
-        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
+        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False, save_tmp=False)
         self.calculate_transformation(args, kwargs, valid_stations, load_tmp=False)
         self.data_store.set("stations", valid_stations, "general")
         self.split_train_val_test()
@@ -94,14 +94,16 @@ class PreProcessing(RunEnvironment):
         else:
             set_stations = stations[index_list]
         logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
-        set_stations = self.check_valid_stations(args, kwargs, set_stations)
+        set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False)
         self.data_store.set("stations", set_stations, scope)
         set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope)
         data_set = DataGenerator(**set_args, **kwargs)
         self.data_store.set("generator", data_set, scope)
+        if set_name == "train":
+            self.data_store.set("transformation", data_set.transformation, "general")

     @staticmethod
-    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True):
+    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True):
        """
        Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the
        given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
@@ -123,7 +125,8 @@
         t_inner.run()
         try:
             # (history, label) = data_gen[station]
-            data = data_gen.get_data_generator(key=station, local_tmp_storage=load_tmp)
+            data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
+                                               save_local_tmp_storage=save_tmp)
             valid_stations.append(station)
             logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
             logging.debug(f"{station}: loading time = {t_inner}")
diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py
index 142acd166604951352ad6686548c2cb76f609ce0..306dfa3079c306e46e05cc5b8fe2361acdcf281f 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -104,7 +104,7 @@ class TestDataGenerator:
         if os.path.exists(file):
             os.remove(file)
         assert not os.path.exists(file)
-        assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep)
+        assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrep)
         t = os.stat(file).st_ctime
         assert os.path.exists(file)
         assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
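
Note on the `setup_transformation` change: the incoming dict is now copied before any mutation, and the string placeholders `"accurate"`/`"estimate"` are only resolved when `mean` actually is a string, so pre-computed numeric statistics pass through untouched. A minimal standalone sketch of that logic (not the class method itself; `(0.0, 1.0)` stubs stand in for `calculate_accurate_transformation` and `calculate_estimated_transformation`):

```python
# Standalone sketch of the patched setup_transformation logic.
# Assumption: (0.0, 1.0) stubs replace the real calculate_*_transformation calls.
def setup_transformation(transformation):
    if transformation is None:
        return None
    transformation = transformation.copy()  # never mutate the caller's dict
    scope = transformation.get("scope", "station")
    method = transformation.get("method", "standardise")  # passed to the calculators in the real code
    mean = transformation.get("mean", None)
    std = transformation.get("std", None)
    if isinstance(mean, str):  # only resolve string placeholders
        if scope == "data":
            if mean == "accurate":
                mean, std = 0.0, 1.0  # stub: calculate_accurate_transformation(method)
            elif mean == "estimate":
                mean, std = 0.0, 1.0  # stub: calculate_estimated_transformation(method)
        transformation["mean"] = mean
        transformation["std"] = std
    return transformation

cfg = {"scope": "data", "method": "standardise", "mean": "estimate"}
resolved = setup_transformation(cfg)
assert cfg["mean"] == "estimate"  # original dict left untouched thanks to copy()
assert resolved["mean"] == 0.0    # placeholder resolved only in the returned copy
```

Copying first means a transformation dict shared through the data store is never modified as a side effect of building a single generator.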
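The former `local_tmp_storage` flag is split into independent load and save switches, which is what lets the station check in `PreProcessing._run` skip writing tmp pickles. A hypothetical call-site fragment (`gen` stands for a fully constructed `DataGenerator`; `"DEBW107"` is the station key used in the tests):

```python
# Old: gen.get_data_generator("DEBW107", local_tmp_storage=False)  # load and save were coupled
data = gen.get_data_generator("DEBW107",
                              load_local_tmp_storage=False,  # process from scratch, ignore any tmp pickle
                              save_local_tmp_storage=False)  # and do not write a tmp pickle afterwards
```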