diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 1de0ab2092b46dfca6281963b260b1fc6bc65387..f259c4030e1293bf565c87f65fe5eff3ee1a31b5 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -7,6 +7,8 @@ from src.data_handling.data_preparation import DataPrep
 import os
 from typing import Union, List, Tuple
 import xarray as xr
+import pickle
+import logging
 
 
 class DataGenerator(keras.utils.Sequence):
@@ -23,6 +25,9 @@ class DataGenerator(keras.utils.Sequence):
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
                  window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
         self.data_path = os.path.abspath(data_path)
+        self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
+        if not os.path.exists(self.data_path_tmp):
+            os.makedirs(self.data_path_tmp)
         self.network = network
         self.stations = helpers.to_list(stations)
         self.variables = variables
@@ -88,7 +93,7 @@ class DataGenerator(keras.utils.Sequence):
         return data.history.transpose("datetime", "window", "Stations", "variables"), \
                data.label.squeeze("Stations").transpose("datetime", "window")
 
-    def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
+    def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
@@ -96,13 +101,33 @@ class DataGenerator(keras.utils.Sequence):
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
-        data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                        **self.kwargs)
-        data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
-        data.transform("datetime", method=self.transform_method)
-        data.make_history_window(self.interpolate_dim, self.window_history_size)
-        data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
-        data.history_label_nan_remove(self.interpolate_dim)
+        try:
+            if not load_tmp:
+                raise FileNotFoundError
+            data = self._load_pickle_data(station, self.variables)
+        except FileNotFoundError:
+            logging.info(f"could not load pickle data for {station}")
+            data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                            **self.kwargs)
+            data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
+            data.transform("datetime", method=self.transform_method)
+            data.make_history_window(self.interpolate_dim, self.window_history_size)
+            data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
+            data.history_label_nan_remove(self.interpolate_dim)
+            self._save_pickle_data(data)
+        return data
+
+    def _save_pickle_data(self, data):
+        file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle")
+        with open(file, "wb") as f:
+            pickle.dump(data, f)
+        logging.debug(f"save pickle data to {file}")
+
+    def _load_pickle_data(self, station, variables):
+        file = os.path.join(self.data_path_tmp, f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle")
+        with open(file, "rb") as f:
+            data = pickle.load(f)
+        logging.debug(f"load pickle data from {file}")
         return data
 
     def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
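Taken together, the data_generator.py hunks add a per-station pickle cache: the file name is keyed on the station id plus the sorted variable list, a cache miss (or load_tmp=False) falls through to the full DataPrep pipeline, and the result is written back to data_path/tmp. Below is a minimal, self-contained sketch of that pattern; load_or_build and expensive_prepare are hypothetical names standing in for get_data_generator and the DataPrep steps, not part of this patch.

    import logging
    import os
    import pickle


    def load_or_build(cache_dir, station, variables, expensive_prepare, load_tmp=True):
        # Cache key mirrors _save_pickle_data/_load_pickle_data: station + sorted variables.
        file = os.path.join(cache_dir, f"{station}_{'_'.join(sorted(variables))}.pickle")
        try:
            if not load_tmp:
                raise FileNotFoundError  # skip the cache on purpose, same trick as above
            with open(file, "rb") as f:
                data = pickle.load(f)
            logging.debug(f"load pickle data from {file}")
        except FileNotFoundError:
            data = expensive_prepare(station, variables)  # stand-in for the DataPrep pipeline
            os.makedirs(cache_dir, exist_ok=True)
            with open(file, "wb") as f:
                pickle.dump(data, f)
            logging.debug(f"save pickle data to {file}")
        return data

Note that the cache key contains only station and variables; window sizes, the transform method, and the remaining kwargs do not appear in the file name, so the tmp directory has to be cleared by hand whenever those settings change.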
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 2a4632d515a36a77b01a09a539da4f51ecd3e07a..2f8b2777b8e7edd9dd2be3e425d3d2f18d81a4b2 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -36,7 +36,7 @@ class PreProcessing(RunEnvironment):
     def _run(self):
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
-        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
+        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
         self.data_store.set("stations", valid_stations, "general")
         self.split_train_val_test()
         self.report_pre_processing()
@@ -97,7 +97,7 @@ class PreProcessing(RunEnvironment):
         self.data_store.set("generator", data_set, scope)
 
     @staticmethod
-    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
+    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True):
         """
         Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the
         given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
@@ -118,9 +118,9 @@ class PreProcessing(RunEnvironment):
         for station in all_stations:
             t_inner.run()
             try:
-                (history, label) = data_gen[station]
+                data = data_gen.get_data_generator(key=station, load_tmp=load_tmp)
                 valid_stations.append(station)
-                logging.debug(f"{station}: history_shape = {history.shape}")
+                logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
                 logging.debug(f"{station}: loading time = {t_inner}")
             except (AttributeError, EmptyQueryResult):
                 continue
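On the pre_processing.py side, the validity check now calls get_data_generator directly with load_tmp=False, so the first pass over each station always rebuilds from raw data and rewrites the pickle; any later access through the generator is then served from the cache. A hypothetical round trip using the load_or_build sketch above (station id and variables are illustrative):

    import tempfile


    def expensive_prepare(station, variables):
        # Hypothetical stand-in: the real pipeline runs interpolate, transform,
        # make_history_window, make_labels and history_label_nan_remove here.
        return {"station": station, "variables": sorted(variables)}


    cache_dir = tempfile.mkdtemp()

    # Validity check (as in check_valid_stations): bypass the cache, write a fresh pickle.
    first = load_or_build(cache_dir, "DEBW107", ["o3", "temp"], expensive_prepare, load_tmp=False)

    # Later access, e.g. when building the train/val/test generators: read the pickle.
    second = load_or_build(cache_dir, "DEBW107", ["o3", "temp"], expensive_prepare)
    assert first == second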